diff --git a/Colab_demo.ipynb b/Colab_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4db2c7efc0c1aed96750599223692f1424147905 --- /dev/null +++ b/Colab_demo.ipynb @@ -0,0 +1,127 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Untitled0.ipynb", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FypCcZkNNt2p" + }, + "source": [ + "%cd /content\n", + "!git clone https://github.com/hzwer/Practical-RIFE" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1wysVHxoN54f" + }, + "source": [ + "!gdown --id 1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj\n", + "!7z e Practical-RIFE/RIFE_trained_model_v3.8.zip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AhbHfRBJRAUt" + }, + "source": [ + "!mkdir /content/Practical-RIFE/train_log\n", + "!mv *.py /content/Practical-RIFE/train_log/\n", + "!mv *.pkl /content/Practical-RIFE/train_log/\n", + "%cd /content/Practical-RIFE/\n", + "!gdown --id 1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc\n", + "!pip3 install -r requirements.txt" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rirngW5uRMdg" + }, + "source": [ + "Please upload your video to content/Practical-RIFE/video.mp4, or use our demo video." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dnLn4aHHPzN3" + }, + "source": [ + "!nvidia-smi\n", + "!python3 inference_video.py --exp=1 --video=demo.mp4 --montage --skip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "77KK6lxHgJhf" + }, + "source": [ + "Our demo.mp4 is 25FPS. You can adjust the parameters for your own perference.\n", + "For example: \n", + "--fps=60 --exp=1 --video=mydemo.avi --png" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0zIBbVE3UfUD", + "cellView": "code" + }, + "source": [ + "from IPython.display import display, Image\n", + "import moviepy.editor as mpy\n", + "display(mpy.ipython_display('demo_4X_100fps.mp4', height=256, max_duration=100.))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tWkJCNgP3zXA" + }, + "source": [ + "!python3 inference_img.py --img demo/I0_0.png demo/I0_1.png\n", + "ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf \"split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1\" output/slomo.gif\n", + "# Image interpolation" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/Practical-RIFE b/Practical-RIFE deleted file mode 160000 index f3e48ceb02e4c21bc8868b03994b98f3402ffb3d..0000000000000000000000000000000000000000 --- a/Practical-RIFE +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f3e48ceb02e4c21bc8868b03994b98f3402ffb3d diff --git a/Practical-RIFE/Colab_demo.ipynb b/Practical-RIFE/Colab_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4db2c7efc0c1aed96750599223692f1424147905 --- /dev/null +++ b/Practical-RIFE/Colab_demo.ipynb @@ -0,0 +1,127 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Untitled0.ipynb", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FypCcZkNNt2p" + }, + "source": [ + "%cd /content\n", + "!git clone https://github.com/hzwer/Practical-RIFE" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1wysVHxoN54f" + }, + "source": [ + "!gdown --id 1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj\n", + "!7z e Practical-RIFE/RIFE_trained_model_v3.8.zip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AhbHfRBJRAUt" + }, + "source": [ + "!mkdir /content/Practical-RIFE/train_log\n", + "!mv *.py /content/Practical-RIFE/train_log/\n", + "!mv *.pkl /content/Practical-RIFE/train_log/\n", + "%cd /content/Practical-RIFE/\n", + "!gdown --id 1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc\n", + "!pip3 install -r requirements.txt" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rirngW5uRMdg" + }, + "source": [ + "Please upload your video to content/Practical-RIFE/video.mp4, or use our demo video." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dnLn4aHHPzN3" + }, + "source": [ + "!nvidia-smi\n", + "!python3 inference_video.py --exp=1 --video=demo.mp4 --montage --skip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "77KK6lxHgJhf" + }, + "source": [ + "Our demo.mp4 is 25FPS. You can adjust the parameters for your own perference.\n", + "For example: \n", + "--fps=60 --exp=1 --video=mydemo.avi --png" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0zIBbVE3UfUD", + "cellView": "code" + }, + "source": [ + "from IPython.display import display, Image\n", + "import moviepy.editor as mpy\n", + "display(mpy.ipython_display('demo_4X_100fps.mp4', height=256, max_duration=100.))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tWkJCNgP3zXA" + }, + "source": [ + "!python3 inference_img.py --img demo/I0_0.png demo/I0_1.png\n", + "ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf \"split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1\" output/slomo.gif\n", + "# Image interpolation" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/Practical-RIFE/inference_img.py b/Practical-RIFE/inference_img.py new file mode 100644 index 0000000000000000000000000000000000000000..cee947ed8a15fe782cf8097ecde0a467eb1e55a3 --- /dev/null +++ b/Practical-RIFE/inference_img.py @@ -0,0 +1,118 @@ +import os +import cv2 +import torch +import argparse +from torch.nn import functional as F +import warnings +warnings.filterwarnings("ignore") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='Interpolation for a pair of images') +parser.add_argument('--img', dest='img', nargs=2, required=True) +parser.add_argument('--exp', default=4, type=int) +parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range') +parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold') +parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') + +args = parser.parse_args() + +try: + try: + from model.RIFE_HDv2 import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v2.x HD model.") + except: + from train_log.RIFE_HDv3 import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v3.x HD model.") +except: + from model.RIFE_HD import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v1.x HD model") +if not hasattr(model, 'version'): + model.version = 0 +model.eval() +model.device() + +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) + +else: + img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED) + img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED) + img0 = cv2.resize(img0, (448, 256)) + img1 = cv2.resize(img1, (448, 256)) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + +n, c, h, w = img0.shape +ph = ((h - 1) // 64 + 1) * 64 +pw = ((w - 1) // 64 + 1) * 64 +padding = (0, pw - w, 0, ph - h) +img0 = F.pad(img0, padding) +img1 = F.pad(img1, padding) + + +if args.ratio: + if model.version >= 3.9: + img_list = [img0, model.inference(img0, img1, args.ratio), img1] + else: + img0_ratio = 0.0 + img1_ratio = 1.0 + if args.ratio <= img0_ratio + args.rthreshold / 2: + middle = img0 + elif args.ratio >= img1_ratio - args.rthreshold / 2: + middle = img1 + else: + tmp_img0 = img0 + tmp_img1 = img1 + for inference_cycle in range(args.rmaxcycles): + middle = model.inference(tmp_img0, tmp_img1) + middle_ratio = ( img0_ratio + img1_ratio ) / 2 + if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2): + break + if args.ratio > middle_ratio: + tmp_img0 = middle + img0_ratio = middle_ratio + else: + tmp_img1 = middle + img1_ratio = middle_ratio + img_list.append(middle) + img_list.append(img1) +else: + if model.version >= 3.9: + img_list = [img0] + n = 2 ** args.exp + for i in range(n-1): + img_list.append(model.inference(img0, img1, (i+1) * 1. / n)) + img_list.append(img1) + else: + img_list = [img0, img1] + for i in range(args.exp): + tmp = [] + for j in range(len(img_list) - 1): + mid = model.inference(img_list[j], img_list[j + 1]) + tmp.append(img_list[j]) + tmp.append(mid) + tmp.append(img1) + img_list = tmp + +if not os.path.exists('output'): + os.mkdir('output') +for i in range(len(img_list)): + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) diff --git a/Practical-RIFE/inference_img_SR.py b/Practical-RIFE/inference_img_SR.py new file mode 100644 index 0000000000000000000000000000000000000000..4ecf2acd4e0becdf65bfb57e01f09aa5bd0594c9 --- /dev/null +++ b/Practical-RIFE/inference_img_SR.py @@ -0,0 +1,69 @@ +import os +import cv2 +import torch +import argparse +from torch.nn import functional as F +import warnings +warnings.filterwarnings("ignore") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='STVSR for a pair of images') +parser.add_argument('--img', dest='img', nargs=2, required=True) +parser.add_argument('--exp', default=2, type=int) +parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') + +args = parser.parse_args() + +from train_log.model import Model +model = Model() +model.device() +model.load_model('train_log') +model.eval() + +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) +else: + img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED) + img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED) + img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + +n, c, h, w = img0.shape +ph = ((h - 1) // 32 + 1) * 32 +pw = ((w - 1) // 32 + 1) * 32 +padding = (0, pw - w, 0, ph - h) +img0 = F.pad(img0, padding) +img1 = F.pad(img1, padding) + +if args.ratio: + print('ratio={}'.format(args.ratio)) + img_list = model.inference(img0, img1, timestep=args.ratio) +else: + n = 2 ** args.exp - 1 + time_list = [0] + for i in range(n): + time_list.append((i+1) * 1. / (n+1)) + time_list.append(1) + print(time_list) + img_list = model.inference(img0, img1, timestep=time_list) + +if not os.path.exists('output'): + os.mkdir('output') +for i in range(len(img_list)): + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) diff --git a/Practical-RIFE/inference_video.py b/Practical-RIFE/inference_video.py new file mode 100644 index 0000000000000000000000000000000000000000..854eef45c57cc6f7027346c610997c643d1e2113 --- /dev/null +++ b/Practical-RIFE/inference_video.py @@ -0,0 +1,293 @@ +import os +import cv2 +import torch +import argparse +import numpy as np +from tqdm import tqdm +from torch.nn import functional as F +import warnings +import _thread +import skvideo.io +from queue import Queue, Empty +from model.pytorch_msssim import ssim_matlab + +warnings.filterwarnings("ignore") + +def transferAudio(sourceVideo, targetVideo): + import shutil + import moviepy.editor + tempAudioFileName = "./temp/audio.mkv" + + # split audio from original video file and store in "temp" directory + if True: + + # clear old "temp" directory if it exits + if os.path.isdir("temp"): + # remove temp directory + shutil.rmtree("temp") + # create new "temp" directory + os.makedirs("temp") + # extract audio from video + os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName)) + + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName)) + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) + print("Audio transfer failed. Interpolated video will have no audio") + else: + print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") + + # remove audio-less video + os.remove(targetNoAudio) + else: + os.remove(targetNoAudio) + + # remove temp directory + shutil.rmtree("temp") + +parser = argparse.ArgumentParser(description='Interpolation for a pair of images') +parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) +parser.add_argument('--img', dest='img', type=str, default=None) +parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') +parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores') +parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video') +parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video') +parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') +parser.add_argument('--fps', dest='fps', type=int, default=None) +parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension') +parser.add_argument('--exp', dest='exp', type=int, default=1) +parser.add_argument('--multi', dest='multi', type=int, default=2) + +args = parser.parse_args() +if args.exp != 1: + args.multi = (2 ** args.exp) +assert (not args.video is None or not args.img is None) +if args.skip: + print("skip flag is abandoned, please refer to issue #207.") +if args.UHD and args.scale==1.0: + args.scale = 0.5 +assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0] +if not args.img is None: + args.png = True + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if(args.fp16): + torch.set_default_tensor_type(torch.cuda.HalfTensor) + +try: + from train_log.RIFE_HDv3 import Model +except: + print("Please download our model from model list") +model = Model() +if not hasattr(model, 'version'): + model.version = 0 +model.load_model(args.modelDir, -1) +print("Loaded 3.x/4.x HD model.") +model.eval() +model.device() + +if not args.video is None: + videoCapture = cv2.VideoCapture(args.video) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + videoCapture.release() + if args.fps is None: + fpsNotAssigned = True + args.fps = fps * args.multi + else: + fpsNotAssigned = False + videogen = skvideo.io.vreader(args.video) + lastframe = next(videogen) + fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') + video_path_wo_ext, ext = os.path.splitext(args.video) + print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) + if args.png == False and fpsNotAssigned == True: + print("The audio will be merged after interpolation process") + else: + print("Will not merge audio because using png or fps flag!") +else: + videogen = [] + for f in os.listdir(args.img): + if 'png' in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key= lambda x:int(x[:-4])) + lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + videogen = videogen[1:] +h, w, _ = lastframe.shape +vid_out_name = None +vid_out = None +if args.png: + if not os.path.exists('vid_out'): + os.mkdir('vid_out') +else: + if args.output is not None: + print("Out") + vid_out_name = args.output + else: + vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext) + print("Width is ", w," and height is ", h) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h)) + +def clear_write_buffer(user_args, write_buffer): + cnt = 0 + while True: + item = write_buffer.get() + if item is None: + break + if user_args.png: + cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) + cnt += 1 + else: + vid_out.write(item[:, :, ::-1]) + +def build_read_buffer(user_args, read_buffer, videogen): + try: + for frame in videogen: + if not user_args.img is None: + frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + if user_args.montage: + frame = frame[:, left: left + w] + read_buffer.put(frame) + except: + pass + read_buffer.put(None) + +def make_inference(I0, I1, n): + global model + if model.version >= 3.9: + res = [] + for i in range(n): + res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale)) + return res + else: + middle = model.inference(I0, I1, args.scale) + if n == 1: + return [middle] + first_half = make_inference(I0, middle, n=n//2) + second_half = make_inference(middle, I1, n=n//2) + if n%2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + +def pad_image(img): + if(args.fp16): + return F.pad(img, padding).half() + else: + return F.pad(img, padding) + +if args.montage: + left = w // 4 + w = w // 2 +tmp = max(128, int(128 / args.scale)) +ph = ((h - 1) // tmp + 1) * tmp +pw = ((w - 1) // tmp + 1) * tmp +padding = (0, pw - w, 0, ph - h) +pbar = tqdm(total=tot_frame) +if args.montage: + lastframe = lastframe[:, left: left + w] +write_buffer = Queue(maxsize=500) +read_buffer = Queue(maxsize=500) +_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen)) +_thread.start_new_thread(clear_write_buffer, (args, write_buffer)) + +I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. +I1 = pad_image(I1) +temp = None # save lastframe when processing static frame + +while True: + if temp is not None: + frame = temp + temp = None + else: + frame = read_buffer.get() + if frame is None: + break + I0 = I1 + I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1) + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + + break_flag = False + if ssim > 0.996: + frame = read_buffer.get() # read a new frame + if frame is None: + break_flag = True + frame = lastframe + else: + temp = frame + I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1) + I1 = model.inference(I0, I1, args.scale) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w] + + if ssim < 0.2: + output = [] + for i in range(args.multi - 1): + output.append(I0) + ''' + output = [] + step = 1 / args.multi + alpha = 0 + for i in range(args.multi - 1): + alpha += step + beta = 1-alpha + output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + ''' + else: + output = make_inference(I0, I1, args.multi-1) + + if args.montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) + for mid in output: + mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) + write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1)) + else: + write_buffer.put(lastframe) + for mid in output: + mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) + write_buffer.put(mid[:h, :w]) + pbar.update(1) + lastframe = frame + if break_flag: + break + +if args.montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) +else: + write_buffer.put(lastframe) +import time +while(not write_buffer.empty()): + time.sleep(0.1) +pbar.close() +if not vid_out is None: + vid_out.release() + +# move audio to new video file if appropriate +# if args.png == False and fpsNotAssigned == True and not args.video is None: +# try: +# transferAudio(args.video, vid_out_name) +# except: +# print("Audio transfer failed. Interpolated video will have no audio") +# targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] +# os.rename(targetNoAudio, vid_out_name) diff --git a/Practical-RIFE/inference_video_enhance.py b/Practical-RIFE/inference_video_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..d3076cd233fc0168d54c2a4b57393473fec6d5a5 --- /dev/null +++ b/Practical-RIFE/inference_video_enhance.py @@ -0,0 +1,201 @@ +import os +import cv2 +import torch +import argparse +import numpy as np +from tqdm import tqdm +from torch.nn import functional as F +import warnings +import _thread +import skvideo.io +from queue import Queue, Empty +from model.pytorch_msssim import ssim_matlab + +warnings.filterwarnings("ignore") + +def transferAudio(sourceVideo, targetVideo): + import shutil + import moviepy.editor + tempAudioFileName = "./temp/audio.mkv" + + # split audio from original video file and store in "temp" directory + if True: + + # clear old "temp" directory if it exits + if os.path.isdir("temp"): + # remove temp directory + shutil.rmtree("temp") + # create new "temp" directory + os.makedirs("temp") + # extract audio from video + os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName)) + + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName)) + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) + print("Audio transfer failed. Interpolated video will have no audio") + else: + print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") + + # remove audio-less video + os.remove(targetNoAudio) + else: + os.remove(targetNoAudio) + + # remove temp directory + shutil.rmtree("temp") + +parser = argparse.ArgumentParser(description='Video SR') +parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) +parser.add_argument('--img', dest='img', type=str, default=None) +parser.add_argument('--model', dest='modelDir', type=str, default='train_log_SAFA', help='directory with trained model files') +parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores') +parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension') + +args = parser.parse_args() +assert (not args.video is None or not args.img is None) +if not args.img is None: + args.png = True + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if(args.fp16): + print('set fp16') + torch.set_default_tensor_type(torch.cuda.HalfTensor) + +try: + from train_log_SAFA.model import Model +except: + print("Please download our model from model list") +model = Model() +model.device() +model.load_model(args.modelDir) +print("Loaded SAFA model.") +model.eval() + +if not args.video is None: + videoCapture = cv2.VideoCapture(args.video) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + videoCapture.release() + fpsNotAssigned = True + videogen = skvideo.io.vreader(args.video) + lastframe = next(videogen) + fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') + video_path_wo_ext, ext = os.path.splitext(args.video) + if args.png == False and fpsNotAssigned == True: + print("The audio will be merged after interpolation process") + else: + print("Will not merge audio because using png or fps flag!") +else: + videogen = [] + for f in os.listdir(args.img): + if 'png' in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key= lambda x:int(x[:-4])) + lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + videogen = videogen[1:] + +h, w, _ = lastframe.shape + +vid_out_name = None +vid_out = None +if args.png: + if not os.path.exists('vid_out'): + os.mkdir('vid_out') +else: + if args.output is not None: + vid_out_name = args.output + else: + vid_out_name = '{}_2X{}'.format(video_path_wo_ext, ext) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps, (w, h)) + +def clear_write_buffer(user_args, write_buffer): + cnt = 0 + while True: + item = write_buffer.get() + if item is None: + break + if user_args.png: + cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) + cnt += 1 + else: + vid_out.write(item[:, :, ::-1]) + +def build_read_buffer(user_args, read_buffer, videogen): + for frame in videogen: + if not user_args.img is None: + frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + # if user_args.montage: + # frame = frame[:, left: left + w] + read_buffer.put(frame) + read_buffer.put(None) + +def pad_image(img): + if(args.fp16): + return F.pad(img, padding, mode='reflect').half() + else: + return F.pad(img, padding, mode='reflect') + +tmp = 64 +ph = ((h - 1) // tmp + 1) * tmp +pw = ((w - 1) // tmp + 1) * tmp +padding = (0, pw - w, 0, ph - h) +pbar = tqdm(total=tot_frame) +write_buffer = Queue(maxsize=500) +read_buffer = Queue(maxsize=500) +_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen)) +_thread.start_new_thread(clear_write_buffer, (args, write_buffer)) + +while True: + frame = read_buffer.get() + if frame is None: + break + # lastframe_2x = cv2.resize(lastframe, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + # frame_2x = cv2.resize(frame, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + I0 = pad_image(torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + I1 = pad_image(torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + if ssim < 0.2: + out = [model.inference(I0, I0, [0])[0], model.inference(I1, I1, [0])[0]] + else: + out = model.inference(I0, I1, [0, 1]) + assert(len(out) == 2) + write_buffer.put((out[0][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + write_buffer.put((out[1][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + lastframe = read_buffer.get() + if lastframe is None: + break + pbar.update(2) + +import time +while(not write_buffer.empty()): + time.sleep(0.1) +pbar.close() +if not vid_out is None: + vid_out.release() + +# move audio to new video file if appropriate +if args.png == False and fpsNotAssigned == True and not args.video is None: + try: + transferAudio(args.video, vid_out_name) + except: + print("Audio transfer failed. Interpolated video will have no audio") + targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] + os.rename(targetNoAudio, vid_out_name) diff --git a/Practical-RIFE/model/__pycache__/loss.cpython-310.pyc b/Practical-RIFE/model/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe05e6558d719f6ef013df2130686d2fae3211ae Binary files /dev/null and b/Practical-RIFE/model/__pycache__/loss.cpython-310.pyc differ diff --git a/Practical-RIFE/model/__pycache__/warplayer.cpython-310.pyc b/Practical-RIFE/model/__pycache__/warplayer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91379f162fff78c380af33f0b43739e2daaa8ea2 Binary files /dev/null and b/Practical-RIFE/model/__pycache__/warplayer.cpython-310.pyc differ diff --git a/Practical-RIFE/model/loss.py b/Practical-RIFE/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..72e5de6af050df7d55c2871a69637077970ddfb9 --- /dev/null +++ b/Practical-RIFE/model/loss.py @@ -0,0 +1,128 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return (loss_map * loss_mask) + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape( + (patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor([ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ]).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat( + [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:] + pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:] + + L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y) + loss = (L1X+L1Y) + return loss + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i+1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/Practical-RIFE/model/pytorch_msssim/__init__.py b/Practical-RIFE/model/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d30326188cf6afacf2fc84c7ae18efe14dae2e --- /dev/null +++ b/Practical-RIFE/model/pytorch_msssim/__init__.py @@ -0,0 +1,200 @@ +import torch +import torch.nn.functional as F +from math import exp +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs ** weights + pow2 = mssim ** weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + dssim = (1 - _ssim) / 2 + return dssim + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/Practical-RIFE/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc b/Practical-RIFE/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce9bdfe8f7613cf8969b4b57516f89709b4dc59a Binary files /dev/null and b/Practical-RIFE/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc differ diff --git a/Practical-RIFE/model/warplayer.py b/Practical-RIFE/model/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..21b0b904cf71b297fd43813134c57d13a3ae9e4a --- /dev/null +++ b/Practical-RIFE/model/warplayer.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) diff --git a/Practical-RIFE/train_log/.DS_Store b/Practical-RIFE/train_log/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..198104583e60199ba4bf4eb0d324c0daadae3363 Binary files /dev/null and b/Practical-RIFE/train_log/.DS_Store differ diff --git a/Practical-RIFE/train_log/IFNet_HDv3.py b/Practical-RIFE/train_log/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e4cf8e196cbcf61527e5d710b8555e712caa49 --- /dev/null +++ b/Practical-RIFE/train_log/IFNet_HDv3.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.warplayer import warp +# from train_log.refine import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True) + ) + +class Head(nn.Module): + def __init__(self): + super(Head, self).__init__() + self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1) + self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\ +) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c//2, 3, 2, 1), + conv(c//2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4*6, 4, 2, 1), + nn.PixelShuffle(2) + ) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7+16, c=192) + self.block1 = IFBlock(8+4+16, c=128) + self.block2 = IFBlock(8+4+16, c=96) + self.block3 = IFBlock(8+4+16, c=64) + self.encode = Head() + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward(self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + loss_cons = 0 + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i]) + if ensemble: + f_, m_ = block[i](torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1-timestep), 1), None, scale=scale_list[i]) + flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (mask + (-m_)) / 2 + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1), flow, scale=scale_list[i]) + if ensemble: + f_, m_ = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], wf1, wf0, 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[3] = (warped_img0 * mask + warped_img1 * (1 - mask)) + if not fastmode: + print('contextnet is removed') + ''' + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[3] = torch.clamp(merged[3] + res, 0, 1) + ''' + return flow_list, mask_list[3], merged diff --git a/Practical-RIFE/train_log/RIFE_HDv3.py b/Practical-RIFE/train_log/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..897c1cc6468919fd08e11b135f446d8fa9ff7a37 --- /dev/null +++ b/Practical-RIFE/train_log/RIFE_HDv3.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from train_log.IFNet_HDv3 import * +import torch.nn.functional as F +from model.loss import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.device() + self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.epe = EPE() + self.version = 4.8 + # self.vgg = VGGPerceptualLoss().to(device) + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank=0): + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } + else: + return param + if rank <= 0: + if torch.cuda.is_available(): + self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))), False) + else: + self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu')), False) + + def save_model(self, path, rank=0): + if rank == 0: + torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) + + def inference(self, img0, img1, timestep=0.5, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [8/scale, 4/scale, 2/scale, 1/scale] + flow, mask, merged = self.flownet(imgs, timestep, scale_list) + return merged[3] + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + scale = [8, 4, 2, 1] + flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + loss_l1 = (merged[3] - gt).abs().mean() + loss_smooth = self.sobel(flow[3], flow[3]*0).mean() + # loss_vgg = self.vgg(merged[2], gt) + if training: + self.optimG.zero_grad() + loss_G = loss_l1 + loss_cons + loss_smooth * 0.1 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[3], { + 'mask': mask, + 'flow': flow[3][:, :2], + 'loss_l1': loss_l1, + 'loss_cons': loss_cons, + 'loss_smooth': loss_smooth, + } diff --git a/Practical-RIFE/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc b/Practical-RIFE/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d69ff6e1283580275df326bdb487fa571113ae64 Binary files /dev/null and b/Practical-RIFE/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc differ diff --git a/Practical-RIFE/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc b/Practical-RIFE/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9324decdc7150c3da99274f3ca695c9ac893aa61 Binary files /dev/null and b/Practical-RIFE/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc differ diff --git a/Practical-RIFE/train_log/flownet.pkl b/Practical-RIFE/train_log/flownet.pkl new file mode 100644 index 0000000000000000000000000000000000000000..aa218a2b4b78f404cdb9d66aebed32113a8f7b6c --- /dev/null +++ b/Practical-RIFE/train_log/flownet.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1ee3186270312a38316e4d53c77b31a60062cfa5636e13d6f0a1dd89bb7b128 +size 21508207 diff --git a/Practical-RIFE/train_log/refine.py b/Practical-RIFE/train_log/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..41b648ec12403f442f8bf0941bed9b0d896f2d87 --- /dev/null +++ b/Practical-RIFE/train_log/refine.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.functional as F + +device = torch.device("cuda") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.LeakyReLU(0.2, True) + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +c = 16 +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2*c) + self.conv3 = Conv2(2*c, 4*c) + self.conv4 = Conv2(4*c, 8*c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.Conv2d(c, 3, 3, 1, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/README.md b/README.md index 38fad76c49de4b6729d4b4ce84aa2802ce6db590..d49580a79fd546e5cd9f09583802dd5ce27ac26c 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,156 @@ ---- -license: apache-2.0 ---- +# roop-unleashed -Download Moore-AnimateAnyone weights by: +[Changelog](#changelog) โ€ข [Usage](#usage) โ€ข [Wiki](https://github.com/C0untFloyd/roop-unleashed/wiki) + + +Uncensored Deepfakes for images and videos without training and an easy-to-use GUI. + + +![Screen](https://github.com/C0untFloyd/roop-unleashed/assets/131583554/6ee6860d-efbe-4337-8c62-a67598863637) + +### Features + +- Platform-independant Browser GUI +- Selection of multiple input/output faces in one go +- Many different swapping modes, first detected, face selections, by gender +- Batch processing of images/videos +- Masking of face occluders using text prompts or automatically +- Optional Face Upscaler/Restoration using different enhancers +- Preview swapping from different video frames +- Live Fake Cam using your webcam +- Extras Tab for cutting videos etc. +- Settings - storing configuration for next session +- Theme Support + +and lots more... + + +## Disclaimer + +This project is for technical and academic use only. +Users of this software are expected to use this software responsibly while abiding the local law. If a face of a real person is being used, users are suggested to get consent from the concerned person and clearly mention that it is a deepfake when posting content online. Developers of this software will not be responsible for actions of end-users. +**Please do not apply it to illegal and unethical scenarios.** + +In the event of violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability + +### Installation + +Please refer to the [wiki](https://github.com/C0untFloyd/roop-unleashed/wiki). + + + + +### Usage + +- Windows: run the `windows_run.bat` from the Installer. +- Linux: `python run.py` + + + Open In Colab + + + +Additional commandline arguments are currently unsupported and settings should be done via the UI. + +> Note: When you run this program for the first time, it will download some models roughly ~2Gb in size. + + + + +### Changelog + +**22.04.2024** v3.9.0 + +- Bugfix: Face detection bounding box corrupt values at weird angles +- Rewrote mask previewing to work with every model +- Switching mask engines toggles text interactivity +- Clearing target files, resets face selection dropdown +- Massive rewrite of swapping architecture, needed for xseg implementation +- Added DFL Xseg Support for partial face occlusion +- Face masking only runs when there is a face detected +- Removed unnecessary toggle checkbox for text masking + + +**22.03.2024** v3.6.5 + +- Bugfix: Installer pulling latest update on first installation +- Bugfix: Regression issue, blurring/erosion missing from face swap +- Exposed erosion and blur amounts to UI +- Using same values for manual masking too + + +**20.03.2024** v3.6.3 + +- Bugfix: Workaround for Gradio Slider Change Bug +- Bugfix: CSS Styling to fix Gradio Image Height Bug +- Made face swapping mask offsets resolution independant +- Show offset mask as overlay +- Changed layout for masking + + +**18.03.2024** v3.6.0 + +- Updated to Gradio 4.21.0 - requiring many changes under the hood +- New manual masking (draw the mask yourself) +- Extras Tab, streamlined cutting/joining videos +- Re-added face selection by gender (on-demand loading, default turned off) +- Removed unnecessary activate live-cam option +- Added time info to preview frame and changed frame slider event to allow faster changes + + +**10.03.2024** v3.5.5 + +- Bugfix: Installer Path Env +- Bugfix: file attributes +- Video processing checks for presence of ffmpeg and displays warning if not found +- Removed gender + age detection to speed up processing. Option removed from UI +- Replaced restoreformer with restoreformer++ +- Live Cam recoded to run separate from virtual cam and without blocking controls +- Swapping with only 1 target face allows selecting from several input faces + + + +**08.01.2024** v3.5.0 + +- Bugfix: wrong access options when creating folders +- New auto rotation of horizontal faces, fixing bad landmark positions (expanded on ![PR 364](https://github.com/C0untFloyd/roop-unleashed/pull/364)) +- Simple VR Option for stereo Images/Movies, best used in selected face mode +- Added RestoreFormer Enhancer - https://github.com/wzhouxiff/RestoreFormer +- Bumped up package versions for onnx/Torch etc. + + +**16.10.2023** v3.3.4 + +**11.8.2023** v2.7.0 + +Initial Gradio Version - old TkInter Version now deprecated + +- Re-added unified padding to face enhancers +- Fixed DMDNet for all resolutions +- Selecting target face now automatically switches swapping mode to selected +- GPU providers are correctly set using the GUI (needs restart currently) +- Local output folder can be opened from page +- Unfinished extras functions disabled for now +- Installer checks out specific commit, allowing to go back to first install +- Updated readme for new gradio version +- Updated Colab + + +# Acknowledgements + +Lots of ideas, code or pre-trained models borrowed from the following projects: + +https://github.com/deepinsight/insightface
+https://github.com/s0md3v/roop
+https://github.com/AUTOMATIC1111/stable-diffusion-webui
+https://github.com/Hillobar/Rope
+https://github.com/TencentARC/GFPGAN
+https://github.com/kadirnar/codeformer-pip
+https://github.com/csxmli2016/DMDNet
+https://github.com/glucauze/sd-webui-faceswaplab
+https://github.com/ykk648/face_power
+ +
+
+Thanks to all developers! -```bash -git lfs install -git clone https://huggingface.co/patrolli/AnimateAnyone -``` \ No newline at end of file diff --git a/__pycache__/handler.cpython-310.pyc b/__pycache__/handler.cpython-310.pyc index 457243ca4690932d3e2f86b89887a39b3f81edec..d6197aae289158b56a4dc88a8f55d490dadcadc5 100644 Binary files a/__pycache__/handler.cpython-310.pyc and b/__pycache__/handler.cpython-310.pyc differ diff --git a/__pycache__/settings.cpython-310.pyc b/__pycache__/settings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..859cc39fb90e23f990991b404857cde0f968c745 Binary files /dev/null and b/__pycache__/settings.cpython-310.pyc differ diff --git a/clip/__init__.py b/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/clip/bpe_simple_vocab_16e6.txt.gz b/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/clip/clip.py b/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f983b7b35a19634bfc941733ab24d69b132ebeac --- /dev/null +++ b/clip/clip.py @@ -0,0 +1,241 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + #if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + # result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + #else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/clip/clipseg.py b/clip/clipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..6adc7e4893cbb2bff31eb822dacf96a7c9a87e27 --- /dev/null +++ b/clip/clipseg.py @@ -0,0 +1,538 @@ +import math +from os.path import basename, dirname, join, isfile +import torch +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules.activation import ReLU + + +def get_prompt_list(prompt): + if prompt == 'plain': + return ['{}'] + elif prompt == 'fixed': + return ['a photo of a {}.'] + elif prompt == 'shuffle': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + else: + raise ValueError('Invalid value for prompt') + + +def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): + """ + Simplified version of multihead attention (taken from torch source code but without tons of if clauses). + The mlp and layer norm come from CLIP. + x: input. + b: multihead attention module. + """ + + x_ = b.ln_1(x) + q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) + tgt_len, bsz, embed_dim = q.size() + + head_dim = embed_dim // b.attn.num_heads + scaling = float(head_dim) ** -0.5 + + q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + + q = q * scaling + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 + if attn_mask is not None: + + + attn_mask_type, attn_mask = attn_mask + n_heads = attn_output_weights.size(0) // attn_mask.size(0) + attn_mask = attn_mask.repeat(n_heads, 1) + + if attn_mask_type == 'cls_token': + # the mask only affects similarities compared to the readout-token. + attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] + # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] + + if attn_mask_type == 'all': + # print(attn_output_weights.shape, attn_mask[:, None].shape) + attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] + + + attn_output_weights = torch.softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = b.attn.out_proj(attn_output) + + x = x + attn_output + x = x + b.mlp(b.ln_2(x)) + + if with_aff: + return x, attn_output_weights + else: + return x + + +class CLIPDenseBase(nn.Module): + + def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): + super().__init__() + + import clip + + # prec = torch.FloatTensor + self.clip_model, _ = clip.load(version, device='cpu', jit=False) + self.model = self.clip_model.visual + + # if not None, scale conv weights such that we obtain n_tokens. + self.n_tokens = n_tokens + + for p in self.clip_model.parameters(): + p.requires_grad_(False) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + self.reduce = nn.Linear(768, reduce_dim) + + self.prompt_list = get_prompt_list(prompt) + + # precomputed prompts + import pickle + if isfile('precomputed_prompt_vectors.pickle'): + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + else: + self.precomputed_prompts = dict() + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + + with torch.no_grad(): + + inp_size = x_inp.shape[2:] + + if self.n_tokens is not None: + stride2 = x_inp.shape[2] // self.n_tokens + conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) + x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) + else: + x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 + + if x.shape[1] != standard_n_tokens: + new_shape = int(math.sqrt(x.shape[1]-1)) + x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] + else: + x = x + self.model.positional_embedding.to(x.dtype) + + x = self.model.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + activations, affinities = [], [] + for i, res_block in enumerate(self.model.transformer.resblocks): + + if mask is not None: + mask_layer, mask_type, mask_tensor = mask + if mask_layer == i or mask_layer == 'all': + # import ipdb; ipdb.set_trace() + size = int(math.sqrt(x.shape[0] - 1)) + + attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) + + else: + attn_mask = None + else: + attn_mask = None + + x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) + + if i in extract_layers: + affinities += [aff_per_head] + + #if self.n_tokens is not None: + # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] + #else: + activations += [x] + + if len(extract_layers) > 0 and i == max(extract_layers) and skip: + print('early skip') + break + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_post(x[:, 0, :]) + + if self.model.proj is not None: + x = x @ self.model.proj + + return x, activations, affinities + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + if self.shift_vector is not None: + return cond + self.shift_vector + else: + return cond + + +def clip_load_untrained(version): + assert version == 'ViT-B/16' + from clip.model import CLIP + from clip.clip import _MODELS, _download + model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() + state_dict = model.state_dict() + + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) + + +class CLIPDensePredT(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, + add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + self.rev_activations = rev_activations + + depth = len(extract_layers) + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + self.version = version + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + if fix_shift: + # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) + self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) + # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) + else: + self.shift_vector = None + + if trans_conv is None: + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + else: + # explicitly define transposed conv kernel size + trans_conv_ks = (trans_conv, trans_conv) + + if not complex_trans_conv: + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + else: + assert trans_conv_ks[0] == trans_conv_ks[1] + + tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4) + + self.trans_conv = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]), + ) + +# self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + self.prompt_list = get_prompt_list(prompt) + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + _activations = activations[::-1] if not self.rev_activations else activations + + a = None + for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + a = self.trans_conv(a) + + if self.n_tokens is not None: + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, + + + +class CLIPDensePredTMasked(CLIPDensePredT): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, + prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, + refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None): + + super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, + n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, + fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, + limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, + n_tokens=n_tokens) + + def visual_forward_masked(self, img_s, seg_s): + return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) + + def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): + + if seg_s is None: + cond = cond_or_img_s + else: + img_s = cond_or_img_s + + with torch.no_grad(): + cond, _, _ = self.visual_forward_masked(img_s, seg_s) + + return super().forward(img_q, cond, return_features=return_features) + + + +class CLIPDenseBaseline(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', cond_layer=0, + extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', + reduce_cond=None, limit_to_clip_only=False, n_tokens=None): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + device = 'cpu' + + # self.cond_layer = cond_layer + self.extract_layer = extract_layer + self.limit_to_clip_only = limit_to_clip_only + self.shift_vector = None + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + assert reduce2_dim is not None + + self.reduce2 = nn.Sequential( + nn.Linear(reduce_dim, reduce2_dim), + nn.ReLU(), + nn.Linear(reduce2_dim, reduce_dim) + ) + + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + + def forward(self, inp_image, conditional=None, return_features=False): + + inp_image = inp_image.to(self.model.positional_embedding.device) + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) + + a = activations[0] + a = self.reduce(a) + a = self.film_mul(cond) * a + self.film_add(cond) + + if self.reduce2 is not None: + a = self.reduce2(a) + + # the original model would execute a transformer block here + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + a = self.trans_conv(a) + + if return_features: + return a, visual_q, cond, activations + else: + return a, + + +class CLIPSegMultiLabel(nn.Module): + + def __init__(self, model) -> None: + super().__init__() + + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + + self.pascal_classes = VOC + + from clip.clipseg import CLIPDensePredT + from general_utils import load_model + # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False) + self.clipseg = load_model(model, strict=False) + + self.clipseg.eval() + + def forward(self, x): + + bs = x.shape[0] + out = torch.ones(21, bs, 352, 352).to(x.device) * -10 + + for class_id, class_name in enumerate(self.pascal_classes): + + fac = 3 if class_name == 'background' else 1 + + with torch.no_grad(): + pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac + + out[class_id] += pred + + + out = out.permute(1, 0, 2, 3) + + return out + + # construct output tensor + diff --git a/clip/model.py b/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..232b7792eb97440642547bd462cf128df9243933 --- /dev/null +++ b/clip/model.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/clip/simple_tokenizer.py b/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/clip/vitseg.py b/clip/vitseg.py new file mode 100644 index 0000000000000000000000000000000000000000..ed621431ddf930fcfa27b5929999776b96fede63 --- /dev/null +++ b/clip/vitseg.py @@ -0,0 +1,286 @@ +import math +from posixpath import basename, dirname, join +# import clip +from clip.model import convert_weights +import torch +import json +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules import activation +from torch.nn.modules.activation import ReLU +from torchvision import transforms + +normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + +from torchvision.models import ResNet + + +def process_prompts(conditional, prompt_list, conditional_map): + # DEPRECATED + + # randomly sample a synonym + words = [conditional_map[int(i)] for i in conditional] + words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] + words = [w.replace('_', ' ') for w in words] + + if prompt_list is not None: + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + else: + prompts = ['a photo of {}'] * (len(words)) + + return [promt.format(w) for promt, w in zip(prompts, words)] + + +class VITDenseBase(nn.Module): + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + with torch.no_grad(): + + x_inp = nnf.interpolate(x_inp, (384, 384)) + + x = self.model.patch_embed(x_inp) + cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.model.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.model.pos_drop(x + self.model.pos_embed) + + activations = [] + for i, block in enumerate(self.model.blocks): + x = block(x) + + if i in extract_layers: + # permute to be compatible with CLIP + activations += [x.permute(1,0,2)] + + x = self.model.norm(x) + x = self.model.head(self.model.pre_logits(x[:, 0])) + + # again for CLIP compatibility + # x = x.permute(1, 0, 2) + + return x, activations, None + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + return cond + + +class VITDensePredT(VITDenseBase): + + def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, + add_calibration=False, process_cond=None, not_pretrained=False): + super().__init__() + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + import timm + self.model = timm.create_model('vit_base_patch16_384', pretrained=True) + self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) + + for p in self.model.parameters(): + p.requires_grad_(False) + + import clip + self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) + # del self.clip_model.visual + + + self.token_shape = (14, 14) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + # self.film = AVAILABLE_BLOCKS['film'](512, 128) + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + # DEPRECATED + # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + trans_conv_ks = (16, 16) + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + if prompt == 'fixed': + self.prompt_list = ['a photo of a {}.'] + elif prompt == 'shuffle': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + elif prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + self.prompt_list = imagenet_templates + + if process_cond is not None: + if process_cond == 'clamp' or process_cond[0] == 'clamp': + + val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 + + def clamp_vec(x): + return torch.clamp(x, -val, val) + + self.process_cond = clamp_vec + + elif process_cond.endswith('.pth'): + + shift = torch.load(process_cond) + def add_shift(x): + return x + shift.to(x.device) + + self.process_cond = add_shift + + import pickle + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + # inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + inp_image_size = inp_image.shape[2:] + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + a = None + for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + if self.trans_conv is not None: + a = self.trans_conv(a) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + a = nnf.interpolate(a, inp_image_size) + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, diff --git a/config_colab.yaml b/config_colab.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c47f3f6f17f35eeb2089e8aba2ff42c80077ba5 --- /dev/null +++ b/config_colab.yaml @@ -0,0 +1,14 @@ +clear_output: true +force_cpu: false +max_threads: 3 +memory_limit: 0 +output_image_format: png +output_template: '{file}_{time}' +output_video_codec: libx264 +output_video_format: mp4 +provider: cuda +selected_theme: Default +server_name: '' +server_port: 0 +server_share: true +video_quality: 14 diff --git a/handler.py b/handler.py index b20d4833393779b64df7588d7c5bded68cf329c9..d68620c5d913549ce5802c23e76680eaae402ba9 100644 --- a/handler.py +++ b/handler.py @@ -31,7 +31,7 @@ import tempfile from rembg import remove import onnxruntime as ort - +import shutil device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -187,54 +187,44 @@ class EndpointHandler(): f.write("="*30 + "\n") def convert_to_playable_format(self, input_path, output_path): - command = [ - "ffmpeg", - "-i", input_path, - "-c:v", "libx264", - "-preset", "fast", - "-crf", "18", - "-y", # Overwrite output file if it exists - output_path - ] - result = subprocess.run(command, capture_output=True, text=True) + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file: + temp_output_path = tmp_file.name + + command = f"ffmpeg -i {input_path} -c:v libx264 -preset fast -crf 18 -y {temp_output_path}" + + # Run the command with shell=True + result = subprocess.run(command, shell=True, capture_output=True, text=True) print("Conversion STDOUT:", result.stdout) print("Conversion STDERR:", result.stderr) if result.returncode != 0: raise RuntimeError(f"FFmpeg conversion failed with exit code {result.returncode}") + shutil.move(temp_output_path, output_path) + def run_rife_interpolation(self, video_path, output_path, multi=2, scale=1.0): base_dir = os.path.dirname(os.path.abspath(__file__)) directory = os.path.join(base_dir, "Practical-RIFE", "inference_video.py") model_directory = os.path.join(base_dir, "Practical-RIFE", "train_log") - command = [ - "python", - directory, - f"--video={video_path}", - f"--output={output_path}", - f"--multi={multi}", - f"--scale={scale}", - f"--model={model_directory}", - ] - - result = subprocess.run(command, capture_output=True, text=True) + command = f"python {directory} --video={video_path} --output={output_path} --multi={multi} --scale={scale} --model={model_directory}" + + # Run the command with shell=True + result = subprocess.run(command, shell=True, capture_output=True, text=True) print(result) print(result.stdout) print(result.stderr) if result.returncode != 0: raise RuntimeError(f"RIFE interpolation failed with exit code {result.returncode}") - self.convert_to_playable_format(output_path, "completed_playable.mp4") + + # Overwrite the RIFE output with the converted playable format + self.convert_to_playable_format(output_path, output_path) def speed_up_video(self, input_path, output_path, factor=4): - command = [ - "ffmpeg", - "-i", input_path, - "-filter:v", f"setpts=PTS/{factor}", - "-an", # Remove audio - output_path - ] - result = subprocess.run(command, capture_output=True, text=True) + command = f"ffmpeg -i {input_path} -filter:v setpts=PTS/{factor} -an {output_path}" + + # Run the command with shell=True + result = subprocess.run(command, shell=True, capture_output=True, text=True) print("Speed Up Video STDOUT:", result.stdout) print("Speed Up Video STDERR:", result.stderr) @@ -242,14 +232,10 @@ class EndpointHandler(): raise RuntimeError(f"FFmpeg speed up failed with exit code {result.returncode}") def slow_down_video(self, input_path, output_path, factor=4): - command = [ - "ffmpeg", - "-i", input_path, - "-filter:v", f"setpts={factor}*PTS", - "-an", # Remove audio - output_path - ] - result = subprocess.run(command, capture_output=True, text=True) + command = f"ffmpeg -i {input_path} -filter:v setpts={factor}*PTS -an {output_path}" + + # Run the command with shell=True + result = subprocess.run(command, shell=True, capture_output=True, text=True) print("Slow Down Video STDOUT:", result.stdout) print("Slow Down Video STDERR:", result.stderr) @@ -319,11 +305,10 @@ class EndpointHandler(): pose_output_path = os.path.join(temp_dir, "pose_videos") # Run the extract_dwpose_from_vid.py script - command = [ - "python", "extract_dwpose_from_vid.py", - "--video_root", video_root - ] - result = subprocess.run(command, capture_output=True, text=True) + command = f'python extract_dwpose_from_vid.py --video_root {video_root}' + + # Run the command with shell=True + result = subprocess.run(command, shell=True, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"Error running extract_dwpose_from_vid.py: {result.stderr}") @@ -377,18 +362,19 @@ class EndpointHandler(): # Perform face swapping # self.print_directory_contents(temp_dir) - # swapped_face_video_path = os.path.join(save_dir, "swapped_face_output.mp4") - # self._swap_face(cropped_face_path, animation_path, swapped_face_video_path) + swapped_face_video_path = os.path.join(save_dir, "swapped_face_output.mp4") + self._swap_face('./good_face.jpeg', animation_path, swapped_face_video_path) # Slow down the produced video by 4x self.print_directory_contents(temp_dir) slowed_down_animation_path = os.path.join(save_dir, "slowed_down_animation_output.mp4") - self.slow_down_video(animation_path, slowed_down_animation_path, factor=4) + self.slow_down_video(swapped_face_video_path, slowed_down_animation_path, factor=4) # Clear CUDA cache before RIFE interpolation torch.cuda.empty_cache() # Perform RIFE interpolation + # self.print_directory_contents(temp_dir) rife_output_path = os.path.join(save_dir, "completed_result.mp4") self.run_rife_interpolation(slowed_down_animation_path, rife_output_path, multi=2, scale=0.5) diff --git a/inference_img.py b/inference_img.py new file mode 100644 index 0000000000000000000000000000000000000000..cee947ed8a15fe782cf8097ecde0a467eb1e55a3 --- /dev/null +++ b/inference_img.py @@ -0,0 +1,118 @@ +import os +import cv2 +import torch +import argparse +from torch.nn import functional as F +import warnings +warnings.filterwarnings("ignore") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='Interpolation for a pair of images') +parser.add_argument('--img', dest='img', nargs=2, required=True) +parser.add_argument('--exp', default=4, type=int) +parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range') +parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold') +parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') + +args = parser.parse_args() + +try: + try: + from model.RIFE_HDv2 import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v2.x HD model.") + except: + from train_log.RIFE_HDv3 import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v3.x HD model.") +except: + from model.RIFE_HD import Model + model = Model() + model.load_model(args.modelDir, -1) + print("Loaded v1.x HD model") +if not hasattr(model, 'version'): + model.version = 0 +model.eval() +model.device() + +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) + +else: + img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED) + img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED) + img0 = cv2.resize(img0, (448, 256)) + img1 = cv2.resize(img1, (448, 256)) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + +n, c, h, w = img0.shape +ph = ((h - 1) // 64 + 1) * 64 +pw = ((w - 1) // 64 + 1) * 64 +padding = (0, pw - w, 0, ph - h) +img0 = F.pad(img0, padding) +img1 = F.pad(img1, padding) + + +if args.ratio: + if model.version >= 3.9: + img_list = [img0, model.inference(img0, img1, args.ratio), img1] + else: + img0_ratio = 0.0 + img1_ratio = 1.0 + if args.ratio <= img0_ratio + args.rthreshold / 2: + middle = img0 + elif args.ratio >= img1_ratio - args.rthreshold / 2: + middle = img1 + else: + tmp_img0 = img0 + tmp_img1 = img1 + for inference_cycle in range(args.rmaxcycles): + middle = model.inference(tmp_img0, tmp_img1) + middle_ratio = ( img0_ratio + img1_ratio ) / 2 + if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2): + break + if args.ratio > middle_ratio: + tmp_img0 = middle + img0_ratio = middle_ratio + else: + tmp_img1 = middle + img1_ratio = middle_ratio + img_list.append(middle) + img_list.append(img1) +else: + if model.version >= 3.9: + img_list = [img0] + n = 2 ** args.exp + for i in range(n-1): + img_list.append(model.inference(img0, img1, (i+1) * 1. / n)) + img_list.append(img1) + else: + img_list = [img0, img1] + for i in range(args.exp): + tmp = [] + for j in range(len(img_list) - 1): + mid = model.inference(img_list[j], img_list[j + 1]) + tmp.append(img_list[j]) + tmp.append(mid) + tmp.append(img1) + img_list = tmp + +if not os.path.exists('output'): + os.mkdir('output') +for i in range(len(img_list)): + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) diff --git a/inference_img_SR.py b/inference_img_SR.py new file mode 100644 index 0000000000000000000000000000000000000000..4ecf2acd4e0becdf65bfb57e01f09aa5bd0594c9 --- /dev/null +++ b/inference_img_SR.py @@ -0,0 +1,69 @@ +import os +import cv2 +import torch +import argparse +from torch.nn import functional as F +import warnings +warnings.filterwarnings("ignore") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='STVSR for a pair of images') +parser.add_argument('--img', dest='img', nargs=2, required=True) +parser.add_argument('--exp', default=2, type=int) +parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') + +args = parser.parse_args() + +from train_log.model import Model +model = Model() +model.device() +model.load_model('train_log') +model.eval() + +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) +else: + img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED) + img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED) + img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + +n, c, h, w = img0.shape +ph = ((h - 1) // 32 + 1) * 32 +pw = ((w - 1) // 32 + 1) * 32 +padding = (0, pw - w, 0, ph - h) +img0 = F.pad(img0, padding) +img1 = F.pad(img1, padding) + +if args.ratio: + print('ratio={}'.format(args.ratio)) + img_list = model.inference(img0, img1, timestep=args.ratio) +else: + n = 2 ** args.exp - 1 + time_list = [0] + for i in range(n): + time_list.append((i+1) * 1. / (n+1)) + time_list.append(1) + print(time_list) + img_list = model.inference(img0, img1, timestep=time_list) + +if not os.path.exists('output'): + os.mkdir('output') +for i in range(len(img_list)): + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) diff --git a/inference_video.py b/inference_video.py new file mode 100644 index 0000000000000000000000000000000000000000..854eef45c57cc6f7027346c610997c643d1e2113 --- /dev/null +++ b/inference_video.py @@ -0,0 +1,293 @@ +import os +import cv2 +import torch +import argparse +import numpy as np +from tqdm import tqdm +from torch.nn import functional as F +import warnings +import _thread +import skvideo.io +from queue import Queue, Empty +from model.pytorch_msssim import ssim_matlab + +warnings.filterwarnings("ignore") + +def transferAudio(sourceVideo, targetVideo): + import shutil + import moviepy.editor + tempAudioFileName = "./temp/audio.mkv" + + # split audio from original video file and store in "temp" directory + if True: + + # clear old "temp" directory if it exits + if os.path.isdir("temp"): + # remove temp directory + shutil.rmtree("temp") + # create new "temp" directory + os.makedirs("temp") + # extract audio from video + os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName)) + + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName)) + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) + print("Audio transfer failed. Interpolated video will have no audio") + else: + print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") + + # remove audio-less video + os.remove(targetNoAudio) + else: + os.remove(targetNoAudio) + + # remove temp directory + shutil.rmtree("temp") + +parser = argparse.ArgumentParser(description='Interpolation for a pair of images') +parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) +parser.add_argument('--img', dest='img', type=str, default=None) +parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') +parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files') +parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores') +parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video') +parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video') +parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') +parser.add_argument('--fps', dest='fps', type=int, default=None) +parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension') +parser.add_argument('--exp', dest='exp', type=int, default=1) +parser.add_argument('--multi', dest='multi', type=int, default=2) + +args = parser.parse_args() +if args.exp != 1: + args.multi = (2 ** args.exp) +assert (not args.video is None or not args.img is None) +if args.skip: + print("skip flag is abandoned, please refer to issue #207.") +if args.UHD and args.scale==1.0: + args.scale = 0.5 +assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0] +if not args.img is None: + args.png = True + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if(args.fp16): + torch.set_default_tensor_type(torch.cuda.HalfTensor) + +try: + from train_log.RIFE_HDv3 import Model +except: + print("Please download our model from model list") +model = Model() +if not hasattr(model, 'version'): + model.version = 0 +model.load_model(args.modelDir, -1) +print("Loaded 3.x/4.x HD model.") +model.eval() +model.device() + +if not args.video is None: + videoCapture = cv2.VideoCapture(args.video) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + videoCapture.release() + if args.fps is None: + fpsNotAssigned = True + args.fps = fps * args.multi + else: + fpsNotAssigned = False + videogen = skvideo.io.vreader(args.video) + lastframe = next(videogen) + fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') + video_path_wo_ext, ext = os.path.splitext(args.video) + print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) + if args.png == False and fpsNotAssigned == True: + print("The audio will be merged after interpolation process") + else: + print("Will not merge audio because using png or fps flag!") +else: + videogen = [] + for f in os.listdir(args.img): + if 'png' in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key= lambda x:int(x[:-4])) + lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + videogen = videogen[1:] +h, w, _ = lastframe.shape +vid_out_name = None +vid_out = None +if args.png: + if not os.path.exists('vid_out'): + os.mkdir('vid_out') +else: + if args.output is not None: + print("Out") + vid_out_name = args.output + else: + vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext) + print("Width is ", w," and height is ", h) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h)) + +def clear_write_buffer(user_args, write_buffer): + cnt = 0 + while True: + item = write_buffer.get() + if item is None: + break + if user_args.png: + cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) + cnt += 1 + else: + vid_out.write(item[:, :, ::-1]) + +def build_read_buffer(user_args, read_buffer, videogen): + try: + for frame in videogen: + if not user_args.img is None: + frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + if user_args.montage: + frame = frame[:, left: left + w] + read_buffer.put(frame) + except: + pass + read_buffer.put(None) + +def make_inference(I0, I1, n): + global model + if model.version >= 3.9: + res = [] + for i in range(n): + res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale)) + return res + else: + middle = model.inference(I0, I1, args.scale) + if n == 1: + return [middle] + first_half = make_inference(I0, middle, n=n//2) + second_half = make_inference(middle, I1, n=n//2) + if n%2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + +def pad_image(img): + if(args.fp16): + return F.pad(img, padding).half() + else: + return F.pad(img, padding) + +if args.montage: + left = w // 4 + w = w // 2 +tmp = max(128, int(128 / args.scale)) +ph = ((h - 1) // tmp + 1) * tmp +pw = ((w - 1) // tmp + 1) * tmp +padding = (0, pw - w, 0, ph - h) +pbar = tqdm(total=tot_frame) +if args.montage: + lastframe = lastframe[:, left: left + w] +write_buffer = Queue(maxsize=500) +read_buffer = Queue(maxsize=500) +_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen)) +_thread.start_new_thread(clear_write_buffer, (args, write_buffer)) + +I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. +I1 = pad_image(I1) +temp = None # save lastframe when processing static frame + +while True: + if temp is not None: + frame = temp + temp = None + else: + frame = read_buffer.get() + if frame is None: + break + I0 = I1 + I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1) + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + + break_flag = False + if ssim > 0.996: + frame = read_buffer.get() # read a new frame + if frame is None: + break_flag = True + frame = lastframe + else: + temp = frame + I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1) + I1 = model.inference(I0, I1, args.scale) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w] + + if ssim < 0.2: + output = [] + for i in range(args.multi - 1): + output.append(I0) + ''' + output = [] + step = 1 / args.multi + alpha = 0 + for i in range(args.multi - 1): + alpha += step + beta = 1-alpha + output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + ''' + else: + output = make_inference(I0, I1, args.multi-1) + + if args.montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) + for mid in output: + mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) + write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1)) + else: + write_buffer.put(lastframe) + for mid in output: + mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) + write_buffer.put(mid[:h, :w]) + pbar.update(1) + lastframe = frame + if break_flag: + break + +if args.montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) +else: + write_buffer.put(lastframe) +import time +while(not write_buffer.empty()): + time.sleep(0.1) +pbar.close() +if not vid_out is None: + vid_out.release() + +# move audio to new video file if appropriate +# if args.png == False and fpsNotAssigned == True and not args.video is None: +# try: +# transferAudio(args.video, vid_out_name) +# except: +# print("Audio transfer failed. Interpolated video will have no audio") +# targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] +# os.rename(targetNoAudio, vid_out_name) diff --git a/inference_video_enhance.py b/inference_video_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..d3076cd233fc0168d54c2a4b57393473fec6d5a5 --- /dev/null +++ b/inference_video_enhance.py @@ -0,0 +1,201 @@ +import os +import cv2 +import torch +import argparse +import numpy as np +from tqdm import tqdm +from torch.nn import functional as F +import warnings +import _thread +import skvideo.io +from queue import Queue, Empty +from model.pytorch_msssim import ssim_matlab + +warnings.filterwarnings("ignore") + +def transferAudio(sourceVideo, targetVideo): + import shutil + import moviepy.editor + tempAudioFileName = "./temp/audio.mkv" + + # split audio from original video file and store in "temp" directory + if True: + + # clear old "temp" directory if it exits + if os.path.isdir("temp"): + # remove temp directory + shutil.rmtree("temp") + # create new "temp" directory + os.makedirs("temp") + # extract audio from video + os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName)) + + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName)) + os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo)) + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) + print("Audio transfer failed. Interpolated video will have no audio") + else: + print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") + + # remove audio-less video + os.remove(targetNoAudio) + else: + os.remove(targetNoAudio) + + # remove temp directory + shutil.rmtree("temp") + +parser = argparse.ArgumentParser(description='Video SR') +parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) +parser.add_argument('--img', dest='img', type=str, default=None) +parser.add_argument('--model', dest='modelDir', type=str, default='train_log_SAFA', help='directory with trained model files') +parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores') +parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension') + +args = parser.parse_args() +assert (not args.video is None or not args.img is None) +if not args.img is None: + args.png = True + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if(args.fp16): + print('set fp16') + torch.set_default_tensor_type(torch.cuda.HalfTensor) + +try: + from train_log_SAFA.model import Model +except: + print("Please download our model from model list") +model = Model() +model.device() +model.load_model(args.modelDir) +print("Loaded SAFA model.") +model.eval() + +if not args.video is None: + videoCapture = cv2.VideoCapture(args.video) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + videoCapture.release() + fpsNotAssigned = True + videogen = skvideo.io.vreader(args.video) + lastframe = next(videogen) + fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') + video_path_wo_ext, ext = os.path.splitext(args.video) + if args.png == False and fpsNotAssigned == True: + print("The audio will be merged after interpolation process") + else: + print("Will not merge audio because using png or fps flag!") +else: + videogen = [] + for f in os.listdir(args.img): + if 'png' in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key= lambda x:int(x[:-4])) + lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + videogen = videogen[1:] + +h, w, _ = lastframe.shape + +vid_out_name = None +vid_out = None +if args.png: + if not os.path.exists('vid_out'): + os.mkdir('vid_out') +else: + if args.output is not None: + vid_out_name = args.output + else: + vid_out_name = '{}_2X{}'.format(video_path_wo_ext, ext) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps, (w, h)) + +def clear_write_buffer(user_args, write_buffer): + cnt = 0 + while True: + item = write_buffer.get() + if item is None: + break + if user_args.png: + cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) + cnt += 1 + else: + vid_out.write(item[:, :, ::-1]) + +def build_read_buffer(user_args, read_buffer, videogen): + for frame in videogen: + if not user_args.img is None: + frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + # if user_args.montage: + # frame = frame[:, left: left + w] + read_buffer.put(frame) + read_buffer.put(None) + +def pad_image(img): + if(args.fp16): + return F.pad(img, padding, mode='reflect').half() + else: + return F.pad(img, padding, mode='reflect') + +tmp = 64 +ph = ((h - 1) // tmp + 1) * tmp +pw = ((w - 1) // tmp + 1) * tmp +padding = (0, pw - w, 0, ph - h) +pbar = tqdm(total=tot_frame) +write_buffer = Queue(maxsize=500) +read_buffer = Queue(maxsize=500) +_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen)) +_thread.start_new_thread(clear_write_buffer, (args, write_buffer)) + +while True: + frame = read_buffer.get() + if frame is None: + break + # lastframe_2x = cv2.resize(lastframe, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + # frame_2x = cv2.resize(frame, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + I0 = pad_image(torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + I1 = pad_image(torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + if ssim < 0.2: + out = [model.inference(I0, I0, [0])[0], model.inference(I1, I1, [0])[0]] + else: + out = model.inference(I0, I1, [0, 1]) + assert(len(out) == 2) + write_buffer.put((out[0][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + write_buffer.put((out[1][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + lastframe = read_buffer.get() + if lastframe is None: + break + pbar.update(2) + +import time +while(not write_buffer.empty()): + time.sleep(0.1) +pbar.close() +if not vid_out is None: + vid_out.release() + +# move audio to new video file if appropriate +if args.png == False and fpsNotAssigned == True and not args.video is None: + try: + transferAudio(args.video, vid_out_name) + except: + print("Audio transfer failed. Interpolated video will have no audio") + targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] + os.rename(targetNoAudio, vid_out_name) diff --git a/installer/installer.py b/installer/installer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab45c18c2288f85d6c25de2923e2d42d561a15b7 --- /dev/null +++ b/installer/installer.py @@ -0,0 +1,87 @@ +import argparse +import glob +import os +import shutil +import site +import subprocess +import sys + + +script_dir = os.getcwd() + + +def run_cmd(cmd, capture_output=False, env=None): + # Run shell commands + return subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) + + +def check_env(): + # If we have access to conda, we are probably in an environment + conda_not_exist = run_cmd("conda", capture_output=True).returncode + if conda_not_exist: + print("Conda is not installed. Exiting...") + sys.exit() + + # Ensure this is a new environment and not the base environment + if os.environ["CONDA_DEFAULT_ENV"] == "base": + print("Create an environment for this project and activate it. Exiting...") + sys.exit() + + +def install_dependencies(): + global MY_PATH + + # Install Git and clone repo + run_cmd("conda install -y -k git") + run_cmd("git clone https://github.com/C0untFloyd/roop-unleashed.git") + os.chdir(MY_PATH) + run_cmd("git checkout c8643a0532f09f84397aaacf526e66db6455d399") + # Installs dependencies from requirements.txt + run_cmd("python -m pip install -r requirements.txt") + + + +def update_dependencies(): + global MY_PATH + + os.chdir(MY_PATH) + # do a hard reset for to update even if there are local changes + run_cmd("git fetch --all") + run_cmd("git reset --hard origin/main") + run_cmd("git pull") + # Installs/Updates dependencies from all requirements.txt + run_cmd("python -m pip install -r requirements.txt") + + +def start_app(): + global MY_PATH + + os.chdir(MY_PATH) + # forward commandline arguments + sys.argv.pop(0) + args = ' '.join(sys.argv) + print("Launching App") + run_cmd(f'python run.py {args}') + + +if __name__ == "__main__": + global MY_PATH + + MY_PATH = "roop-unleashed" + + + # Verifies we are in a conda environment + check_env() + + # If webui has already been installed, skip and run + if not os.path.exists(MY_PATH): + install_dependencies() + else: + # moved update from batch to here, because of batch limitations + updatechoice = input("Check for Updates? [y/n]").lower() + if updatechoice == "y": + update_dependencies() + + # Run the model with webui + os.chdir(script_dir) + start_app() diff --git a/installer/windows_run.bat b/installer/windows_run.bat new file mode 100644 index 0000000000000000000000000000000000000000..5441a00d9b98d305caffe4c2391c09f371e58c4c --- /dev/null +++ b/installer/windows_run.bat @@ -0,0 +1,99 @@ +@echo off + +REM No CLI arguments supported anymore +set COMMANDLINE_ARGS= + +cd /D "%~dp0" + +echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniconda which can not be silently installed under a path with spaces. && goto end + +set PATH=%PATH%;%SystemRoot%\system32 + +@rem config +set INSTALL_DIR=%cd%\installer_files +set CONDA_ROOT_PREFIX=%cd%\installer_files\conda +set INSTALL_ENV_DIR=%cd%\installer_files\env +set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe +set FFMPEG_DOWNLOAD_URL=https://github.com/GyanD/codexffmpeg/releases/download/2023-06-21-git-1bcb8a7338/ffmpeg-2023-06-21-git-1bcb8a7338-essentials_build.zip +set INSTALL_FFMPEG_DIR=%cd%\installer_files\ffmpeg +set INSIGHTFACE_PACKAGE_URL=https://github.com/C0untFloyd/roop-unleashed/releases/download/3.6.6/insightface-0.7.3-cp310-cp310-win_amd64.whl +set INSIGHTFACE_PACKAGE_PATH=%INSTALL_DIR%\insightface-0.7.3-cp310-cp310-win_amd64.whl + +set conda_exists=F +set ffmpeg_exists=F + +@rem figure out whether git and conda needs to be installed +call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1 +if "%ERRORLEVEL%" EQU "0" set conda_exists=T + +@rem Check if FFmpeg is already in PATH +where ffmpeg >nul 2>&1 +if "%ERRORLEVEL%" EQU "0" ( + echo FFmpeg is already installed. + set ffmpeg_exists=T +) + +@rem (if necessary) install git and conda into a contained environment + +@rem download conda +if "%conda_exists%" == "F" ( + echo Downloading Miniconda from %MINICONDA_DOWNLOAD_URL% to %INSTALL_DIR%\miniconda_installer.exe + mkdir "%INSTALL_DIR%" + call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe" || ( echo. && echo Miniconda failed to download. && goto end ) + echo Installing Miniconda to %CONDA_ROOT_PREFIX% + start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX% + + @rem test the conda binary + echo Miniconda version: + call "%CONDA_ROOT_PREFIX%\_conda.exe" --version || ( echo. && echo Miniconda not found. && goto end ) +) + +@rem create the installer env +if not exist "%INSTALL_ENV_DIR%" ( + echo Creating Conda Environment + call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10 || ( echo. && echo ERROR: Conda environment creation failed. && goto end ) + @rem check if conda environment was actually created + if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end ) + @rem activate installer env + call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end ) + @rem Download insightface package + echo Downloading insightface package from %INSIGHTFACE_PACKAGE_URL% to %INSIGHTFACE_PACKAGE_PATH% + call curl -Lk "%INSIGHTFACE_PACKAGE_URL%" > "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package failed to download. && goto end ) + @rem install insightface package using pip + echo Installing insightface package + call pip install "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package installation failed. && goto end ) +) + +@rem Download and install FFmpeg if not already installed +if "%ffmpeg_exists%" == "F" ( + if not exist "%INSTALL_FFMPEG_DIR%" ( + echo Downloading ffmpeg from %FFMPEG_DOWNLOAD_URL% to %INSTALL_DIR% + call curl -Lk "%FFMPEG_DOWNLOAD_URL%" > "%INSTALL_DIR%\ffmpeg.zip" || ( echo. && echo ffmpeg failed to download. && goto end ) + call powershell -command "Expand-Archive -Force '%INSTALL_DIR%\ffmpeg.zip' '%INSTALL_DIR%\'" + cd "installer_files" + setlocal EnableExtensions EnableDelayedExpansion + for /f "tokens=*" %%f in ('dir /s /b /ad "ffmpeg\*"') do ( + ren "%%f" "ffmpeg" + ) + endlocal + setx PATH "%INSTALL_FFMPEG_DIR%\bin\;%PATH%" + echo To use videos, you need to restart roop after this installation. + cd .. + ) +) else ( + echo Skipping FFmpeg installation as it is already available. +) + +@rem setup installer env +@rem check if conda environment was actually created +if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end ) +@rem activate installer env +call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end ) +echo Launching roop unleashed +call python installer.py %COMMANDLINE_ARGS% + +echo. +echo Done! + +:end +pause diff --git a/model/__pycache__/loss.cpython-310.pyc b/model/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe05e6558d719f6ef013df2130686d2fae3211ae Binary files /dev/null and b/model/__pycache__/loss.cpython-310.pyc differ diff --git a/model/__pycache__/warplayer.cpython-310.pyc b/model/__pycache__/warplayer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91379f162fff78c380af33f0b43739e2daaa8ea2 Binary files /dev/null and b/model/__pycache__/warplayer.cpython-310.pyc differ diff --git a/model/loss.py b/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..72e5de6af050df7d55c2871a69637077970ddfb9 --- /dev/null +++ b/model/loss.py @@ -0,0 +1,128 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return (loss_map * loss_mask) + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape( + (patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor([ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ]).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat( + [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:] + pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:] + + L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y) + loss = (L1X+L1Y) + return loss + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i+1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/model/pytorch_msssim/__init__.py b/model/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d30326188cf6afacf2fc84c7ae18efe14dae2e --- /dev/null +++ b/model/pytorch_msssim/__init__.py @@ -0,0 +1,200 @@ +import torch +import torch.nn.functional as F +from math import exp +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs ** weights + pow2 = mssim ** weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + dssim = (1 - _ssim) / 2 + return dssim + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc b/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce9bdfe8f7613cf8969b4b57516f89709b4dc59a Binary files /dev/null and b/model/pytorch_msssim/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/warplayer.py b/model/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..21b0b904cf71b297fd43813134c57d13a3ae9e4a --- /dev/null +++ b/model/warplayer.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) diff --git a/models/CLIP/rd64-uni-refined.pth b/models/CLIP/rd64-uni-refined.pth new file mode 100644 index 0000000000000000000000000000000000000000..1004abde5a060f41b188410756adb7cc3ea379ea --- /dev/null +++ b/models/CLIP/rd64-uni-refined.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4956f9a7978a75630b08c9d6ec075b7c51cf43b4751b686e3a011d4012ddc9d +size 4720707 diff --git a/models/CodeFormer/CodeFormerv0.1.onnx b/models/CodeFormer/CodeFormerv0.1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6368465c9df3b6e698faec3b47793bb7e602e0e2 --- /dev/null +++ b/models/CodeFormer/CodeFormerv0.1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9aa48fc4b21224d85784c9a58885201284ec8e590b988126db2c07495b421d36 +size 376821951 diff --git a/models/DMDNet.pth b/models/DMDNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..969651e65d2a14acd530f394a12b2675edbc742c --- /dev/null +++ b/models/DMDNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70daeb4b1fd10f241043b587d892a941f2651d7322db02f06ff64b166537f65c +size 603684323 diff --git a/models/Frame/deoldify_artistic.onnx b/models/Frame/deoldify_artistic.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05aa08f4872ea5bd6830e9b5ec7a5d23982c923b --- /dev/null +++ b/models/Frame/deoldify_artistic.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f +size 255024891 diff --git a/models/Frame/deoldify_stable.onnx b/models/Frame/deoldify_stable.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa36af257e550ba1fbaf613af40b9113b505f7b9 --- /dev/null +++ b/models/Frame/deoldify_stable.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d69dbecde018fe3d630a35ac850ac590b23e359c8349d8404b467bbfe4a0b9 +size 873359997 diff --git a/models/Frame/isnet-general-use.onnx b/models/Frame/isnet-general-use.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aae8625d60df68f7a2c7fa770814a3e6eb30612a --- /dev/null +++ b/models/Frame/isnet-general-use.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60920e99c45464f2ba57bee2ad08c919a52bbf852739e96947fbb4358c0d964a +size 178648008 diff --git a/models/Frame/lsdir_x4.onnx b/models/Frame/lsdir_x4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4138f44f53bd574e61fea87dbaa5aa0a8851617c --- /dev/null +++ b/models/Frame/lsdir_x4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c0073607ab48e91a56a180f6928597362fef9f0924cc91325aab8ce8cf1032c +size 66938051 diff --git a/models/Frame/real_esrgan_x2.onnx b/models/Frame/real_esrgan_x2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e8ffe3dfe6111fd51e8206cdf8c1cc986702f1bb --- /dev/null +++ b/models/Frame/real_esrgan_x2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28e6925e46301ba7a4cbfeaae41cefe043dd5941423094a3db8b176d837bf1dd +size 69524246 diff --git a/models/Frame/real_esrgan_x4.onnx b/models/Frame/real_esrgan_x4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..21b902c065cdbfea1c30e7f457c1e2438b675ccc --- /dev/null +++ b/models/Frame/real_esrgan_x4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4139cc1585d04851ccd41570b0f76e775c96e064ca292d5372b6031704dda0d3 +size 69464831 diff --git a/models/GFPGANv1.4.onnx b/models/GFPGANv1.4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..70ce511a0f7073017ee8d1a12dec525047ae358f --- /dev/null +++ b/models/GFPGANv1.4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5060d6c8d84851bbb8da630bea59b56414b49923a2b9304fb08f72d4c98f0aeb +size 340256688 diff --git a/models/GPEN-BFR-512.onnx b/models/GPEN-BFR-512.onnx new file mode 100644 index 0000000000000000000000000000000000000000..13eec093eb1b7133416a5bc9959b1fdc07987ba2 --- /dev/null +++ b/models/GPEN-BFR-512.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0960f836488735444d508b588e44fb5dfd19c68fde9163ad7878aa24d1d5115e +size 284250449 diff --git a/models/buffalo_l.zip b/models/buffalo_l.zip new file mode 100644 index 0000000000000000000000000000000000000000..3c8de83bd164f9ac2abc7eb4486ff01b1fd3af4f --- /dev/null +++ b/models/buffalo_l.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80ffe37d8a5940d59a7384c201a2a38d4741f2f3c51eef46ebb28218a7b0ca2f +size 288621354 diff --git a/models/buffalo_l/1k3d68.onnx b/models/buffalo_l/1k3d68.onnx new file mode 100644 index 0000000000000000000000000000000000000000..221aa2f02a6faccddb2723529e1f93c7db2edbdc --- /dev/null +++ b/models/buffalo_l/1k3d68.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +size 143607619 diff --git a/models/buffalo_l/2d106det.onnx b/models/buffalo_l/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/models/buffalo_l/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/models/buffalo_l/det_10g.onnx b/models/buffalo_l/det_10g.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/models/buffalo_l/det_10g.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/models/buffalo_l/genderage.onnx b/models/buffalo_l/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/models/buffalo_l/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/models/buffalo_l/w600k_r50.onnx b/models/buffalo_l/w600k_r50.onnx new file mode 100644 index 0000000000000000000000000000000000000000..571d2bb9ffd76399b23260620b9101b20bcc4e99 --- /dev/null +++ b/models/buffalo_l/w600k_r50.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43 +size 174383860 diff --git a/models/restoreformer_plus_plus.onnx b/models/restoreformer_plus_plus.onnx new file mode 100644 index 0000000000000000000000000000000000000000..54dbba7932a864c0820b036a045ea0774f5370de --- /dev/null +++ b/models/restoreformer_plus_plus.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4db5a89902b6a2d452446f5721245a6f7185f699b6aec7b77285adb4d504337 +size 294264812 diff --git a/models/xseg.onnx b/models/xseg.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6d6d3e341bb1194d5a3da18b776ceec79d455869 --- /dev/null +++ b/models/xseg.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b57328efcb839d85973164b617ceee9dfe6cfcb2c82e8a033bba9f4f09b27e5 +size 70327737 diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..64218bc23688632a08c98ec4a0451ed46f8ed5e5 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +check_untyped_defs = True +disallow_any_generics = True +disallow_untyped_calls = True +disallow_untyped_defs = True +ignore_missing_imports = True +strict_optional = False diff --git a/pretrained_weights/.huggingface/.gitignore b/pretrained_weights/.huggingface/.gitignore deleted file mode 100644 index f59ec20aabf5842d237244ece8c81ab184faeac1..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file diff --git a/pretrained_weights/.huggingface/download/denoising_unet.pth.metadata b/pretrained_weights/.huggingface/download/denoising_unet.pth.metadata deleted file mode 100644 index eafbd6bd1a6f7bd7e71a7083abdac124c3628a19..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/denoising_unet.pth.metadata +++ /dev/null @@ -1,3 +0,0 @@ -7fd8ff9dfb0e75c7b7fa2689e4795347b5996b9a -b9e5a2c34fac369e8a922972ca2210916c6af175a0dad907deccf6235816ad52 -1717139633.586755 diff --git a/pretrained_weights/.huggingface/download/image_encoder/config.json.metadata b/pretrained_weights/.huggingface/download/image_encoder/config.json.metadata deleted file mode 100644 index 9b6fa5f50758ee2dc58094bcdf98e83b4255127e..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/image_encoder/config.json.metadata +++ /dev/null @@ -1,3 +0,0 @@ -42bc0ee1726b141d49f519a6ea02ccfbf073db2e -251e37d8a59724357a8887da1716fad7b791b9c0 -1717139565.2024844 diff --git a/pretrained_weights/.huggingface/download/image_encoder/pytorch_model.bin.lock b/pretrained_weights/.huggingface/download/image_encoder/pytorch_model.bin.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/.huggingface/download/image_encoder/pytorch_model.bin.metadata b/pretrained_weights/.huggingface/download/image_encoder/pytorch_model.bin.metadata deleted file mode 100644 index 461e06d831e4bb771d17f6748f4f1b09aad032b9..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/image_encoder/pytorch_model.bin.metadata +++ /dev/null @@ -1,3 +0,0 @@ -42bc0ee1726b141d49f519a6ea02ccfbf073db2e -89d2aa29b5fdf64f3ad4f45fb4227ea98bc45156bbae673b85be1af7783dbabb -1717139581.629301 diff --git a/pretrained_weights/.huggingface/download/motion_module.pth.lock b/pretrained_weights/.huggingface/download/motion_module.pth.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/.huggingface/download/motion_module.pth.metadata b/pretrained_weights/.huggingface/download/motion_module.pth.metadata deleted file mode 100644 index 3e1964efd150f1abc1562c24b0fca1f5dbf14f90..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/motion_module.pth.metadata +++ /dev/null @@ -1,3 +0,0 @@ -7fd8ff9dfb0e75c7b7fa2689e4795347b5996b9a -0d11e01a281b39880da2efeea892215c1313e5713fca3d100a7fbb72ee312ef9 -1717139665.959808 diff --git a/pretrained_weights/.huggingface/download/pose_guider.pth.lock b/pretrained_weights/.huggingface/download/pose_guider.pth.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/.huggingface/download/pose_guider.pth.metadata b/pretrained_weights/.huggingface/download/pose_guider.pth.metadata deleted file mode 100644 index 499916acab2de9a6c9333d9446eba14b71249711..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/pose_guider.pth.metadata +++ /dev/null @@ -1,3 +0,0 @@ -7fd8ff9dfb0e75c7b7fa2689e4795347b5996b9a -1a8b7c1b4db92980fd977b4fd003c1396bbae9a9cdea00c35d452136d5e4f488 -1717139666.369367 diff --git a/pretrained_weights/.huggingface/download/reference_unet.pth.lock b/pretrained_weights/.huggingface/download/reference_unet.pth.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/.huggingface/download/reference_unet.pth.metadata b/pretrained_weights/.huggingface/download/reference_unet.pth.metadata deleted file mode 100644 index b865cbffd8b52d79153794e94a5ed3c4ba8448d6..0000000000000000000000000000000000000000 --- a/pretrained_weights/.huggingface/download/reference_unet.pth.metadata +++ /dev/null @@ -1,3 +0,0 @@ -7fd8ff9dfb0e75c7b7fa2689e4795347b5996b9a -beddccb08d49a8b29b0f4d6d456c6521d4382a8d8d48884fa60ba8802509c214 -1717139727.397555 diff --git a/pretrained_weights/DWPose/.huggingface/.gitignore b/pretrained_weights/DWPose/.huggingface/.gitignore deleted file mode 100644 index f59ec20aabf5842d237244ece8c81ab184faeac1..0000000000000000000000000000000000000000 --- a/pretrained_weights/DWPose/.huggingface/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file diff --git a/pretrained_weights/DWPose/.huggingface/download/dw-ll_ucoco_384.onnx.lock b/pretrained_weights/DWPose/.huggingface/download/dw-ll_ucoco_384.onnx.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/DWPose/.huggingface/download/dw-ll_ucoco_384.onnx.metadata b/pretrained_weights/DWPose/.huggingface/download/dw-ll_ucoco_384.onnx.metadata deleted file mode 100644 index 4d3e0ecce6244949f19b42a51686935ac0c681d4..0000000000000000000000000000000000000000 --- a/pretrained_weights/DWPose/.huggingface/download/dw-ll_ucoco_384.onnx.metadata +++ /dev/null @@ -1,3 +0,0 @@ -1a7144101628d69ee7a3768d1ee3a094070dc388 -724f4ff2439ed61afb86fb8a1951ec39c6220682803b4a8bd4f598cd913b1843 -1717139583.58808 diff --git a/pretrained_weights/DWPose/.huggingface/download/yolox_l.onnx.lock b/pretrained_weights/DWPose/.huggingface/download/yolox_l.onnx.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/DWPose/.huggingface/download/yolox_l.onnx.metadata b/pretrained_weights/DWPose/.huggingface/download/yolox_l.onnx.metadata deleted file mode 100644 index 9dd18348f719ba572c059d6b335190cb181f8ade..0000000000000000000000000000000000000000 --- a/pretrained_weights/DWPose/.huggingface/download/yolox_l.onnx.metadata +++ /dev/null @@ -1,3 +0,0 @@ -1a7144101628d69ee7a3768d1ee3a094070dc388 -7860ae79de6c89a3c1eb72ae9a2756c0ccfbe04b7791bb5880afabd97855a411 -1717139586.6107028 diff --git a/pretrained_weights/sd-vae-ft-mse/.huggingface/.gitignore b/pretrained_weights/sd-vae-ft-mse/.huggingface/.gitignore deleted file mode 100644 index f59ec20aabf5842d237244ece8c81ab184faeac1..0000000000000000000000000000000000000000 --- a/pretrained_weights/sd-vae-ft-mse/.huggingface/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file diff --git a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/config.json.lock b/pretrained_weights/sd-vae-ft-mse/.huggingface/download/config.json.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/config.json.metadata b/pretrained_weights/sd-vae-ft-mse/.huggingface/download/config.json.metadata deleted file mode 100644 index 43fc0b594ff39b4f1e3ad54dc3ddc5f443247b7b..0000000000000000000000000000000000000000 --- a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/config.json.metadata +++ /dev/null @@ -1,3 +0,0 @@ -31f26fdeee1355a5c34592e401dd41e45d25a493 -0db26717579be63eb0ddbf15b43faa43700dfe5a -1717139586.8893416 diff --git a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/diffusion_pytorch_model.bin.lock b/pretrained_weights/sd-vae-ft-mse/.huggingface/download/diffusion_pytorch_model.bin.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/diffusion_pytorch_model.bin.metadata b/pretrained_weights/sd-vae-ft-mse/.huggingface/download/diffusion_pytorch_model.bin.metadata deleted file mode 100644 index f52552a0f352f33a03e5aece3a5a03aa8a365999..0000000000000000000000000000000000000000 --- a/pretrained_weights/sd-vae-ft-mse/.huggingface/download/diffusion_pytorch_model.bin.metadata +++ /dev/null @@ -1,3 +0,0 @@ -31f26fdeee1355a5c34592e401dd41e45d25a493 -1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc -1717139591.4496999 diff --git a/pretrained_weights/stable-diffusion-v1-5/.huggingface/.gitignore b/pretrained_weights/stable-diffusion-v1-5/.huggingface/.gitignore deleted file mode 100644 index f59ec20aabf5842d237244ece8c81ab184faeac1..0000000000000000000000000000000000000000 --- a/pretrained_weights/stable-diffusion-v1-5/.huggingface/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file diff --git a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/config.json.lock b/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/config.json.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/config.json.metadata b/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/config.json.metadata deleted file mode 100644 index 0bdab58a5b9815b2b387adc8bb3fd136589c97a8..0000000000000000000000000000000000000000 --- a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/config.json.metadata +++ /dev/null @@ -1,3 +0,0 @@ -1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9 -1a02ee8abc93e840ffbcb2d68b66ccbcb74b3ab3 -1717139517.935449 diff --git a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/diffusion_pytorch_model.bin.lock b/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/diffusion_pytorch_model.bin.lock deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/diffusion_pytorch_model.bin.metadata b/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/diffusion_pytorch_model.bin.metadata deleted file mode 100644 index dcb5fa5749e8ae076703d4312d49b1b3ce3800e8..0000000000000000000000000000000000000000 --- a/pretrained_weights/stable-diffusion-v1-5/.huggingface/download/unet/diffusion_pytorch_model.bin.metadata +++ /dev/null @@ -1,3 +0,0 @@ -1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9 -c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4 -1717139564.9288538 diff --git a/requirements.txt b/requirements.txt index b5bef6b2e321566bc7aed3fda4ab184ba0e056b3..26d424de803dc18fa918ff2501bf6369e6a6eb1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,62 +1,19 @@ --extra-index-url https://download.pytorch.org/whl/cu118 -numpy==1.23.5 -opencv-python==4.8.1.78 -onnx==1.14.0 +numpy==1.26.4 +gradio==4.29.0 +opencv-python==4.9.0.80 +onnx==1.16.0 insightface==0.7.3 -psutil==5.9.5 -tk==0.1.0 -pillow==9.5.0 -torch==2.0.1+cu118; sys_platform != 'darwin' -torch==2.0.1; sys_platform == 'darwin' -torchvision==0.15.2+cu118; sys_platform != 'darwin' -torchvision==0.15.2; sys_platform == 'darwin' -onnxruntime==1.16.3; sys_platform == 'darwin' and platform_machine != 'arm64' -onnxruntime-silicon==1.13.1; sys_platform == 'darwin' and platform_machine == 'arm64' -onnxruntime-gpu==1.16.3; sys_platform != 'darwin' -tensorflow==2.13.0rc1; sys_platform == 'darwin' -tensorflow==2.12.0; sys_platform != 'darwin' -opennsfw2==0.10.2 -protobuf==4.23.2 -tqdm==4.66.1 -gfpgan==1.3.8 -gradio==3.41.2 -onnxruntime-coreml==1.13.1; python_version == '3.9' and sys_platform == 'darwin' and platform_machine != 'arm64' -transformers==4.41.1 -controlnet-aux==0.0.7 - -# Add additional dependencies -diffusers==0.24.0 -omegaconf==2.2.3 - -# Face swap related dependencies -facenet-pytorch==2.5.2 -dlib==19.22.0 - -# Stuff huggingface is complaining about -einops==0.4.1 -av==11.0.0 - -# Additional dependencies from the first list not present in the second list -accelerate==0.21.0 -clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a -decord==0.6.0 -gradio_client==0.5.0 -imageio==2.33.0 -imageio-ffmpeg==0.4.9 -scikit-image==0.21.0 -scikit-learn==1.3.2 -scipy==1.11.4 -torchdiffeq==0.2.3 -torchmetrics==1.2.1 -torchsde==0.2.5 - - -# Additional dependencies for RIFE -sk-video==1.1.10 -moviepy==1.0.3 - -requests==2.32.3 - - -rembg \ No newline at end of file +psutil==5.9.6 +torch==2.1.2+cu118; sys_platform != 'darwin' +torch==2.1.2; sys_platform == 'darwin' +torchvision==0.16.2+cu118; sys_platform != 'darwin' +torchvision==0.16.2; sys_platform == 'darwin' +onnxruntime==1.17.1; sys_platform == 'darwin' and platform_machine != 'arm64' +onnxruntime-silicon==1.17.1; sys_platform == 'darwin' and platform_machine == 'arm64' +onnxruntime-gpu==1.17.1; sys_platform != 'darwin' +tqdm==4.66.4 +ftfy +regex +pyvirtualcam diff --git a/roop-unleashed b/roop-unleashed deleted file mode 160000 index ed6e3dbcf875213051dbc3b095e570afd3557463..0000000000000000000000000000000000000000 --- a/roop-unleashed +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ed6e3dbcf875213051dbc3b095e570afd3557463 diff --git a/roop-unleashed.ipynb b/roop-unleashed.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5bfa7b67796c0d150d83b4674e804e4491116bb0 --- /dev/null +++ b/roop-unleashed.ipynb @@ -0,0 +1,208 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "collapsed_sections": [ + "UdQ1VHdI8lCf" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Colab for roop-unleashed - Gradio version\n", + "https://github.com/C0untFloyd/roop-unleashed\n" + ], + "metadata": { + "id": "G9BdiCppV6AS" + } + }, + { + "cell_type": "markdown", + "source": [ + "Install CUDA V11.8 on Google Cloud Compute" + ], + "metadata": { + "id": "CanIXgLJgaOj" + } + }, + { + "cell_type": "code", + "source": [ + "!apt-get -y update\n", + "!apt-get -y install cuda-toolkit-11-8\n", + "import os\n", + "os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11/lib64\"\n", + "os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11.8/lib64\"" + ], + "metadata": { + "id": "96GE4UgYg3Ej" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Installing & preparing requirements" + ], + "metadata": { + "id": "0ZYRNb0AWLLW" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t1yPuhdySqCq" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/C0untFloyd/roop-unleashed.git\n", + "%cd roop-unleashed\n", + "!mv config_colab.yaml config.yaml\n", + "!pip install pip install -r requirements.txt" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Running roop-unleashed with default config" + ], + "metadata": { + "id": "u_4JQiSlV9Fi" + } + }, + { + "cell_type": "code", + "source": [ + "!python run.py" + ], + "metadata": { + "id": "Is6U2huqSzLE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Download generated images folder\n", + "(only needed if you want to zip the generated output)" + ], + "metadata": { + "id": "UdQ1VHdI8lCf" + } + }, + { + "cell_type": "code", + "source": [ + "import shutil\n", + "import os\n", + "from google.colab import files\n", + "\n", + "def zip_directory(directory_path, zip_path):\n", + " shutil.make_archive(zip_path, 'zip', directory_path)\n", + "\n", + "# Set the directory path you want to download\n", + "directory_path = '/content/roop-unleashed/output'\n", + "\n", + "# Set the zip file name\n", + "zip_filename = 'fake_output.zip'\n", + "\n", + "# Zip the directory\n", + "zip_directory(directory_path, zip_filename)\n", + "\n", + "# Download the zip file\n", + "files.download(zip_filename+'.zip')\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "oYjWveAmw10X", + "outputId": "5b4c3650-f951-434a-c650-5525a8a70c1e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "download(\"download_789eab11-93d2-4880-adf3-6aceee0cc5f9\", \"fake_output.zip.zip\", 80125)" + ] + }, + "metadata": {} + } + ] + } + ] +} diff --git a/roop-unleashed/.flake8 b/roop-unleashed/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..43a1b76932b6cb62486ec7e925caf1853693a403 --- /dev/null +++ b/roop-unleashed/.flake8 @@ -0,0 +1,3 @@ +[flake8] +select = E3, E4, F +per-file-ignores = roop/core.py:E402 \ No newline at end of file diff --git a/roop-unleashed/LICENSE b/roop-unleashed/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/roop-unleashed/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/roop-unleashed/README.md b/roop-unleashed/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d49580a79fd546e5cd9f09583802dd5ce27ac26c --- /dev/null +++ b/roop-unleashed/README.md @@ -0,0 +1,156 @@ +# roop-unleashed + +[Changelog](#changelog) โ€ข [Usage](#usage) โ€ข [Wiki](https://github.com/C0untFloyd/roop-unleashed/wiki) + + +Uncensored Deepfakes for images and videos without training and an easy-to-use GUI. + + +![Screen](https://github.com/C0untFloyd/roop-unleashed/assets/131583554/6ee6860d-efbe-4337-8c62-a67598863637) + +### Features + +- Platform-independant Browser GUI +- Selection of multiple input/output faces in one go +- Many different swapping modes, first detected, face selections, by gender +- Batch processing of images/videos +- Masking of face occluders using text prompts or automatically +- Optional Face Upscaler/Restoration using different enhancers +- Preview swapping from different video frames +- Live Fake Cam using your webcam +- Extras Tab for cutting videos etc. +- Settings - storing configuration for next session +- Theme Support + +and lots more... + + +## Disclaimer + +This project is for technical and academic use only. +Users of this software are expected to use this software responsibly while abiding the local law. If a face of a real person is being used, users are suggested to get consent from the concerned person and clearly mention that it is a deepfake when posting content online. Developers of this software will not be responsible for actions of end-users. +**Please do not apply it to illegal and unethical scenarios.** + +In the event of violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability + +### Installation + +Please refer to the [wiki](https://github.com/C0untFloyd/roop-unleashed/wiki). + + + + +### Usage + +- Windows: run the `windows_run.bat` from the Installer. +- Linux: `python run.py` + + + Open In Colab + + + +Additional commandline arguments are currently unsupported and settings should be done via the UI. + +> Note: When you run this program for the first time, it will download some models roughly ~2Gb in size. + + + + +### Changelog + +**22.04.2024** v3.9.0 + +- Bugfix: Face detection bounding box corrupt values at weird angles +- Rewrote mask previewing to work with every model +- Switching mask engines toggles text interactivity +- Clearing target files, resets face selection dropdown +- Massive rewrite of swapping architecture, needed for xseg implementation +- Added DFL Xseg Support for partial face occlusion +- Face masking only runs when there is a face detected +- Removed unnecessary toggle checkbox for text masking + + +**22.03.2024** v3.6.5 + +- Bugfix: Installer pulling latest update on first installation +- Bugfix: Regression issue, blurring/erosion missing from face swap +- Exposed erosion and blur amounts to UI +- Using same values for manual masking too + + +**20.03.2024** v3.6.3 + +- Bugfix: Workaround for Gradio Slider Change Bug +- Bugfix: CSS Styling to fix Gradio Image Height Bug +- Made face swapping mask offsets resolution independant +- Show offset mask as overlay +- Changed layout for masking + + +**18.03.2024** v3.6.0 + +- Updated to Gradio 4.21.0 - requiring many changes under the hood +- New manual masking (draw the mask yourself) +- Extras Tab, streamlined cutting/joining videos +- Re-added face selection by gender (on-demand loading, default turned off) +- Removed unnecessary activate live-cam option +- Added time info to preview frame and changed frame slider event to allow faster changes + + +**10.03.2024** v3.5.5 + +- Bugfix: Installer Path Env +- Bugfix: file attributes +- Video processing checks for presence of ffmpeg and displays warning if not found +- Removed gender + age detection to speed up processing. Option removed from UI +- Replaced restoreformer with restoreformer++ +- Live Cam recoded to run separate from virtual cam and without blocking controls +- Swapping with only 1 target face allows selecting from several input faces + + + +**08.01.2024** v3.5.0 + +- Bugfix: wrong access options when creating folders +- New auto rotation of horizontal faces, fixing bad landmark positions (expanded on ![PR 364](https://github.com/C0untFloyd/roop-unleashed/pull/364)) +- Simple VR Option for stereo Images/Movies, best used in selected face mode +- Added RestoreFormer Enhancer - https://github.com/wzhouxiff/RestoreFormer +- Bumped up package versions for onnx/Torch etc. + + +**16.10.2023** v3.3.4 + +**11.8.2023** v2.7.0 + +Initial Gradio Version - old TkInter Version now deprecated + +- Re-added unified padding to face enhancers +- Fixed DMDNet for all resolutions +- Selecting target face now automatically switches swapping mode to selected +- GPU providers are correctly set using the GUI (needs restart currently) +- Local output folder can be opened from page +- Unfinished extras functions disabled for now +- Installer checks out specific commit, allowing to go back to first install +- Updated readme for new gradio version +- Updated Colab + + +# Acknowledgements + +Lots of ideas, code or pre-trained models borrowed from the following projects: + +https://github.com/deepinsight/insightface
+https://github.com/s0md3v/roop
+https://github.com/AUTOMATIC1111/stable-diffusion-webui
+https://github.com/Hillobar/Rope
+https://github.com/TencentARC/GFPGAN
+https://github.com/kadirnar/codeformer-pip
+https://github.com/csxmli2016/DMDNet
+https://github.com/glucauze/sd-webui-faceswaplab
+https://github.com/ykk648/face_power
+ +
+
+Thanks to all developers! + diff --git a/roop-unleashed/__pycache__/settings.cpython-310.pyc b/roop-unleashed/__pycache__/settings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..859cc39fb90e23f990991b404857cde0f968c745 Binary files /dev/null and b/roop-unleashed/__pycache__/settings.cpython-310.pyc differ diff --git a/roop-unleashed/clip/__init__.py b/roop-unleashed/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/roop-unleashed/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/roop-unleashed/clip/bpe_simple_vocab_16e6.txt.gz b/roop-unleashed/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/roop-unleashed/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/roop-unleashed/clip/clip.py b/roop-unleashed/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f983b7b35a19634bfc941733ab24d69b132ebeac --- /dev/null +++ b/roop-unleashed/clip/clip.py @@ -0,0 +1,241 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + #if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + # result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + #else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/roop-unleashed/clip/clipseg.py b/roop-unleashed/clip/clipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..6adc7e4893cbb2bff31eb822dacf96a7c9a87e27 --- /dev/null +++ b/roop-unleashed/clip/clipseg.py @@ -0,0 +1,538 @@ +import math +from os.path import basename, dirname, join, isfile +import torch +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules.activation import ReLU + + +def get_prompt_list(prompt): + if prompt == 'plain': + return ['{}'] + elif prompt == 'fixed': + return ['a photo of a {}.'] + elif prompt == 'shuffle': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + else: + raise ValueError('Invalid value for prompt') + + +def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): + """ + Simplified version of multihead attention (taken from torch source code but without tons of if clauses). + The mlp and layer norm come from CLIP. + x: input. + b: multihead attention module. + """ + + x_ = b.ln_1(x) + q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) + tgt_len, bsz, embed_dim = q.size() + + head_dim = embed_dim // b.attn.num_heads + scaling = float(head_dim) ** -0.5 + + q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + + q = q * scaling + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 + if attn_mask is not None: + + + attn_mask_type, attn_mask = attn_mask + n_heads = attn_output_weights.size(0) // attn_mask.size(0) + attn_mask = attn_mask.repeat(n_heads, 1) + + if attn_mask_type == 'cls_token': + # the mask only affects similarities compared to the readout-token. + attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] + # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] + + if attn_mask_type == 'all': + # print(attn_output_weights.shape, attn_mask[:, None].shape) + attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] + + + attn_output_weights = torch.softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = b.attn.out_proj(attn_output) + + x = x + attn_output + x = x + b.mlp(b.ln_2(x)) + + if with_aff: + return x, attn_output_weights + else: + return x + + +class CLIPDenseBase(nn.Module): + + def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): + super().__init__() + + import clip + + # prec = torch.FloatTensor + self.clip_model, _ = clip.load(version, device='cpu', jit=False) + self.model = self.clip_model.visual + + # if not None, scale conv weights such that we obtain n_tokens. + self.n_tokens = n_tokens + + for p in self.clip_model.parameters(): + p.requires_grad_(False) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + self.reduce = nn.Linear(768, reduce_dim) + + self.prompt_list = get_prompt_list(prompt) + + # precomputed prompts + import pickle + if isfile('precomputed_prompt_vectors.pickle'): + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + else: + self.precomputed_prompts = dict() + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + + with torch.no_grad(): + + inp_size = x_inp.shape[2:] + + if self.n_tokens is not None: + stride2 = x_inp.shape[2] // self.n_tokens + conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) + x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) + else: + x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 + + if x.shape[1] != standard_n_tokens: + new_shape = int(math.sqrt(x.shape[1]-1)) + x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] + else: + x = x + self.model.positional_embedding.to(x.dtype) + + x = self.model.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + activations, affinities = [], [] + for i, res_block in enumerate(self.model.transformer.resblocks): + + if mask is not None: + mask_layer, mask_type, mask_tensor = mask + if mask_layer == i or mask_layer == 'all': + # import ipdb; ipdb.set_trace() + size = int(math.sqrt(x.shape[0] - 1)) + + attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) + + else: + attn_mask = None + else: + attn_mask = None + + x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) + + if i in extract_layers: + affinities += [aff_per_head] + + #if self.n_tokens is not None: + # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] + #else: + activations += [x] + + if len(extract_layers) > 0 and i == max(extract_layers) and skip: + print('early skip') + break + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_post(x[:, 0, :]) + + if self.model.proj is not None: + x = x @ self.model.proj + + return x, activations, affinities + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + if self.shift_vector is not None: + return cond + self.shift_vector + else: + return cond + + +def clip_load_untrained(version): + assert version == 'ViT-B/16' + from clip.model import CLIP + from clip.clip import _MODELS, _download + model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() + state_dict = model.state_dict() + + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) + + +class CLIPDensePredT(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, + add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + self.rev_activations = rev_activations + + depth = len(extract_layers) + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + self.version = version + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + if fix_shift: + # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) + self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) + # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) + else: + self.shift_vector = None + + if trans_conv is None: + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + else: + # explicitly define transposed conv kernel size + trans_conv_ks = (trans_conv, trans_conv) + + if not complex_trans_conv: + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + else: + assert trans_conv_ks[0] == trans_conv_ks[1] + + tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4) + + self.trans_conv = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]), + ) + +# self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + self.prompt_list = get_prompt_list(prompt) + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + _activations = activations[::-1] if not self.rev_activations else activations + + a = None + for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + a = self.trans_conv(a) + + if self.n_tokens is not None: + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, + + + +class CLIPDensePredTMasked(CLIPDensePredT): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, + prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, + refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None): + + super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, + n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, + fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, + limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, + n_tokens=n_tokens) + + def visual_forward_masked(self, img_s, seg_s): + return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) + + def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): + + if seg_s is None: + cond = cond_or_img_s + else: + img_s = cond_or_img_s + + with torch.no_grad(): + cond, _, _ = self.visual_forward_masked(img_s, seg_s) + + return super().forward(img_q, cond, return_features=return_features) + + + +class CLIPDenseBaseline(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', cond_layer=0, + extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', + reduce_cond=None, limit_to_clip_only=False, n_tokens=None): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + device = 'cpu' + + # self.cond_layer = cond_layer + self.extract_layer = extract_layer + self.limit_to_clip_only = limit_to_clip_only + self.shift_vector = None + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + assert reduce2_dim is not None + + self.reduce2 = nn.Sequential( + nn.Linear(reduce_dim, reduce2_dim), + nn.ReLU(), + nn.Linear(reduce2_dim, reduce_dim) + ) + + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + + def forward(self, inp_image, conditional=None, return_features=False): + + inp_image = inp_image.to(self.model.positional_embedding.device) + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) + + a = activations[0] + a = self.reduce(a) + a = self.film_mul(cond) * a + self.film_add(cond) + + if self.reduce2 is not None: + a = self.reduce2(a) + + # the original model would execute a transformer block here + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + a = self.trans_conv(a) + + if return_features: + return a, visual_q, cond, activations + else: + return a, + + +class CLIPSegMultiLabel(nn.Module): + + def __init__(self, model) -> None: + super().__init__() + + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + + self.pascal_classes = VOC + + from clip.clipseg import CLIPDensePredT + from general_utils import load_model + # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False) + self.clipseg = load_model(model, strict=False) + + self.clipseg.eval() + + def forward(self, x): + + bs = x.shape[0] + out = torch.ones(21, bs, 352, 352).to(x.device) * -10 + + for class_id, class_name in enumerate(self.pascal_classes): + + fac = 3 if class_name == 'background' else 1 + + with torch.no_grad(): + pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac + + out[class_id] += pred + + + out = out.permute(1, 0, 2, 3) + + return out + + # construct output tensor + diff --git a/roop-unleashed/clip/model.py b/roop-unleashed/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..232b7792eb97440642547bd462cf128df9243933 --- /dev/null +++ b/roop-unleashed/clip/model.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/roop-unleashed/clip/simple_tokenizer.py b/roop-unleashed/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/roop-unleashed/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/roop-unleashed/clip/vitseg.py b/roop-unleashed/clip/vitseg.py new file mode 100644 index 0000000000000000000000000000000000000000..ed621431ddf930fcfa27b5929999776b96fede63 --- /dev/null +++ b/roop-unleashed/clip/vitseg.py @@ -0,0 +1,286 @@ +import math +from posixpath import basename, dirname, join +# import clip +from clip.model import convert_weights +import torch +import json +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules import activation +from torch.nn.modules.activation import ReLU +from torchvision import transforms + +normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + +from torchvision.models import ResNet + + +def process_prompts(conditional, prompt_list, conditional_map): + # DEPRECATED + + # randomly sample a synonym + words = [conditional_map[int(i)] for i in conditional] + words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] + words = [w.replace('_', ' ') for w in words] + + if prompt_list is not None: + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + else: + prompts = ['a photo of {}'] * (len(words)) + + return [promt.format(w) for promt, w in zip(prompts, words)] + + +class VITDenseBase(nn.Module): + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + with torch.no_grad(): + + x_inp = nnf.interpolate(x_inp, (384, 384)) + + x = self.model.patch_embed(x_inp) + cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.model.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.model.pos_drop(x + self.model.pos_embed) + + activations = [] + for i, block in enumerate(self.model.blocks): + x = block(x) + + if i in extract_layers: + # permute to be compatible with CLIP + activations += [x.permute(1,0,2)] + + x = self.model.norm(x) + x = self.model.head(self.model.pre_logits(x[:, 0])) + + # again for CLIP compatibility + # x = x.permute(1, 0, 2) + + return x, activations, None + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + return cond + + +class VITDensePredT(VITDenseBase): + + def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, + add_calibration=False, process_cond=None, not_pretrained=False): + super().__init__() + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + import timm + self.model = timm.create_model('vit_base_patch16_384', pretrained=True) + self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) + + for p in self.model.parameters(): + p.requires_grad_(False) + + import clip + self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) + # del self.clip_model.visual + + + self.token_shape = (14, 14) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + # self.film = AVAILABLE_BLOCKS['film'](512, 128) + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + # DEPRECATED + # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + trans_conv_ks = (16, 16) + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + if prompt == 'fixed': + self.prompt_list = ['a photo of a {}.'] + elif prompt == 'shuffle': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + elif prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + self.prompt_list = imagenet_templates + + if process_cond is not None: + if process_cond == 'clamp' or process_cond[0] == 'clamp': + + val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 + + def clamp_vec(x): + return torch.clamp(x, -val, val) + + self.process_cond = clamp_vec + + elif process_cond.endswith('.pth'): + + shift = torch.load(process_cond) + def add_shift(x): + return x + shift.to(x.device) + + self.process_cond = add_shift + + import pickle + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + # inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + inp_image_size = inp_image.shape[2:] + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + a = None + for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + if self.trans_conv is not None: + a = self.trans_conv(a) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + a = nnf.interpolate(a, inp_image_size) + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, diff --git a/roop-unleashed/config_colab.yaml b/roop-unleashed/config_colab.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c47f3f6f17f35eeb2089e8aba2ff42c80077ba5 --- /dev/null +++ b/roop-unleashed/config_colab.yaml @@ -0,0 +1,14 @@ +clear_output: true +force_cpu: false +max_threads: 3 +memory_limit: 0 +output_image_format: png +output_template: '{file}_{time}' +output_video_codec: libx264 +output_video_format: mp4 +provider: cuda +selected_theme: Default +server_name: '' +server_port: 0 +server_share: true +video_quality: 14 diff --git a/roop-unleashed/installer/installer.py b/roop-unleashed/installer/installer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab45c18c2288f85d6c25de2923e2d42d561a15b7 --- /dev/null +++ b/roop-unleashed/installer/installer.py @@ -0,0 +1,87 @@ +import argparse +import glob +import os +import shutil +import site +import subprocess +import sys + + +script_dir = os.getcwd() + + +def run_cmd(cmd, capture_output=False, env=None): + # Run shell commands + return subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) + + +def check_env(): + # If we have access to conda, we are probably in an environment + conda_not_exist = run_cmd("conda", capture_output=True).returncode + if conda_not_exist: + print("Conda is not installed. Exiting...") + sys.exit() + + # Ensure this is a new environment and not the base environment + if os.environ["CONDA_DEFAULT_ENV"] == "base": + print("Create an environment for this project and activate it. Exiting...") + sys.exit() + + +def install_dependencies(): + global MY_PATH + + # Install Git and clone repo + run_cmd("conda install -y -k git") + run_cmd("git clone https://github.com/C0untFloyd/roop-unleashed.git") + os.chdir(MY_PATH) + run_cmd("git checkout c8643a0532f09f84397aaacf526e66db6455d399") + # Installs dependencies from requirements.txt + run_cmd("python -m pip install -r requirements.txt") + + + +def update_dependencies(): + global MY_PATH + + os.chdir(MY_PATH) + # do a hard reset for to update even if there are local changes + run_cmd("git fetch --all") + run_cmd("git reset --hard origin/main") + run_cmd("git pull") + # Installs/Updates dependencies from all requirements.txt + run_cmd("python -m pip install -r requirements.txt") + + +def start_app(): + global MY_PATH + + os.chdir(MY_PATH) + # forward commandline arguments + sys.argv.pop(0) + args = ' '.join(sys.argv) + print("Launching App") + run_cmd(f'python run.py {args}') + + +if __name__ == "__main__": + global MY_PATH + + MY_PATH = "roop-unleashed" + + + # Verifies we are in a conda environment + check_env() + + # If webui has already been installed, skip and run + if not os.path.exists(MY_PATH): + install_dependencies() + else: + # moved update from batch to here, because of batch limitations + updatechoice = input("Check for Updates? [y/n]").lower() + if updatechoice == "y": + update_dependencies() + + # Run the model with webui + os.chdir(script_dir) + start_app() diff --git a/roop-unleashed/installer/windows_run.bat b/roop-unleashed/installer/windows_run.bat new file mode 100644 index 0000000000000000000000000000000000000000..5441a00d9b98d305caffe4c2391c09f371e58c4c --- /dev/null +++ b/roop-unleashed/installer/windows_run.bat @@ -0,0 +1,99 @@ +@echo off + +REM No CLI arguments supported anymore +set COMMANDLINE_ARGS= + +cd /D "%~dp0" + +echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniconda which can not be silently installed under a path with spaces. && goto end + +set PATH=%PATH%;%SystemRoot%\system32 + +@rem config +set INSTALL_DIR=%cd%\installer_files +set CONDA_ROOT_PREFIX=%cd%\installer_files\conda +set INSTALL_ENV_DIR=%cd%\installer_files\env +set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe +set FFMPEG_DOWNLOAD_URL=https://github.com/GyanD/codexffmpeg/releases/download/2023-06-21-git-1bcb8a7338/ffmpeg-2023-06-21-git-1bcb8a7338-essentials_build.zip +set INSTALL_FFMPEG_DIR=%cd%\installer_files\ffmpeg +set INSIGHTFACE_PACKAGE_URL=https://github.com/C0untFloyd/roop-unleashed/releases/download/3.6.6/insightface-0.7.3-cp310-cp310-win_amd64.whl +set INSIGHTFACE_PACKAGE_PATH=%INSTALL_DIR%\insightface-0.7.3-cp310-cp310-win_amd64.whl + +set conda_exists=F +set ffmpeg_exists=F + +@rem figure out whether git and conda needs to be installed +call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1 +if "%ERRORLEVEL%" EQU "0" set conda_exists=T + +@rem Check if FFmpeg is already in PATH +where ffmpeg >nul 2>&1 +if "%ERRORLEVEL%" EQU "0" ( + echo FFmpeg is already installed. + set ffmpeg_exists=T +) + +@rem (if necessary) install git and conda into a contained environment + +@rem download conda +if "%conda_exists%" == "F" ( + echo Downloading Miniconda from %MINICONDA_DOWNLOAD_URL% to %INSTALL_DIR%\miniconda_installer.exe + mkdir "%INSTALL_DIR%" + call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe" || ( echo. && echo Miniconda failed to download. && goto end ) + echo Installing Miniconda to %CONDA_ROOT_PREFIX% + start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX% + + @rem test the conda binary + echo Miniconda version: + call "%CONDA_ROOT_PREFIX%\_conda.exe" --version || ( echo. && echo Miniconda not found. && goto end ) +) + +@rem create the installer env +if not exist "%INSTALL_ENV_DIR%" ( + echo Creating Conda Environment + call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10 || ( echo. && echo ERROR: Conda environment creation failed. && goto end ) + @rem check if conda environment was actually created + if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end ) + @rem activate installer env + call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end ) + @rem Download insightface package + echo Downloading insightface package from %INSIGHTFACE_PACKAGE_URL% to %INSIGHTFACE_PACKAGE_PATH% + call curl -Lk "%INSIGHTFACE_PACKAGE_URL%" > "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package failed to download. && goto end ) + @rem install insightface package using pip + echo Installing insightface package + call pip install "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package installation failed. && goto end ) +) + +@rem Download and install FFmpeg if not already installed +if "%ffmpeg_exists%" == "F" ( + if not exist "%INSTALL_FFMPEG_DIR%" ( + echo Downloading ffmpeg from %FFMPEG_DOWNLOAD_URL% to %INSTALL_DIR% + call curl -Lk "%FFMPEG_DOWNLOAD_URL%" > "%INSTALL_DIR%\ffmpeg.zip" || ( echo. && echo ffmpeg failed to download. && goto end ) + call powershell -command "Expand-Archive -Force '%INSTALL_DIR%\ffmpeg.zip' '%INSTALL_DIR%\'" + cd "installer_files" + setlocal EnableExtensions EnableDelayedExpansion + for /f "tokens=*" %%f in ('dir /s /b /ad "ffmpeg\*"') do ( + ren "%%f" "ffmpeg" + ) + endlocal + setx PATH "%INSTALL_FFMPEG_DIR%\bin\;%PATH%" + echo To use videos, you need to restart roop after this installation. + cd .. + ) +) else ( + echo Skipping FFmpeg installation as it is already available. +) + +@rem setup installer env +@rem check if conda environment was actually created +if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end ) +@rem activate installer env +call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end ) +echo Launching roop unleashed +call python installer.py %COMMANDLINE_ARGS% + +echo. +echo Done! + +:end +pause diff --git a/roop-unleashed/models/CLIP/rd64-uni-refined.pth b/roop-unleashed/models/CLIP/rd64-uni-refined.pth new file mode 100644 index 0000000000000000000000000000000000000000..1004abde5a060f41b188410756adb7cc3ea379ea --- /dev/null +++ b/roop-unleashed/models/CLIP/rd64-uni-refined.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4956f9a7978a75630b08c9d6ec075b7c51cf43b4751b686e3a011d4012ddc9d +size 4720707 diff --git a/roop-unleashed/models/CodeFormer/CodeFormerv0.1.onnx b/roop-unleashed/models/CodeFormer/CodeFormerv0.1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6368465c9df3b6e698faec3b47793bb7e602e0e2 --- /dev/null +++ b/roop-unleashed/models/CodeFormer/CodeFormerv0.1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9aa48fc4b21224d85784c9a58885201284ec8e590b988126db2c07495b421d36 +size 376821951 diff --git a/roop-unleashed/models/DMDNet.pth b/roop-unleashed/models/DMDNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..969651e65d2a14acd530f394a12b2675edbc742c --- /dev/null +++ b/roop-unleashed/models/DMDNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70daeb4b1fd10f241043b587d892a941f2651d7322db02f06ff64b166537f65c +size 603684323 diff --git a/roop-unleashed/models/Frame/deoldify_artistic.onnx b/roop-unleashed/models/Frame/deoldify_artistic.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05aa08f4872ea5bd6830e9b5ec7a5d23982c923b --- /dev/null +++ b/roop-unleashed/models/Frame/deoldify_artistic.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f +size 255024891 diff --git a/roop-unleashed/models/Frame/deoldify_stable.onnx b/roop-unleashed/models/Frame/deoldify_stable.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa36af257e550ba1fbaf613af40b9113b505f7b9 --- /dev/null +++ b/roop-unleashed/models/Frame/deoldify_stable.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d69dbecde018fe3d630a35ac850ac590b23e359c8349d8404b467bbfe4a0b9 +size 873359997 diff --git a/roop-unleashed/models/Frame/isnet-general-use.onnx b/roop-unleashed/models/Frame/isnet-general-use.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aae8625d60df68f7a2c7fa770814a3e6eb30612a --- /dev/null +++ b/roop-unleashed/models/Frame/isnet-general-use.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60920e99c45464f2ba57bee2ad08c919a52bbf852739e96947fbb4358c0d964a +size 178648008 diff --git a/roop-unleashed/models/Frame/lsdir_x4.onnx b/roop-unleashed/models/Frame/lsdir_x4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4138f44f53bd574e61fea87dbaa5aa0a8851617c --- /dev/null +++ b/roop-unleashed/models/Frame/lsdir_x4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c0073607ab48e91a56a180f6928597362fef9f0924cc91325aab8ce8cf1032c +size 66938051 diff --git a/roop-unleashed/models/Frame/real_esrgan_x2.onnx b/roop-unleashed/models/Frame/real_esrgan_x2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e8ffe3dfe6111fd51e8206cdf8c1cc986702f1bb --- /dev/null +++ b/roop-unleashed/models/Frame/real_esrgan_x2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28e6925e46301ba7a4cbfeaae41cefe043dd5941423094a3db8b176d837bf1dd +size 69524246 diff --git a/roop-unleashed/models/Frame/real_esrgan_x4.onnx b/roop-unleashed/models/Frame/real_esrgan_x4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..21b902c065cdbfea1c30e7f457c1e2438b675ccc --- /dev/null +++ b/roop-unleashed/models/Frame/real_esrgan_x4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4139cc1585d04851ccd41570b0f76e775c96e064ca292d5372b6031704dda0d3 +size 69464831 diff --git a/roop-unleashed/models/GFPGANv1.4.onnx b/roop-unleashed/models/GFPGANv1.4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..70ce511a0f7073017ee8d1a12dec525047ae358f --- /dev/null +++ b/roop-unleashed/models/GFPGANv1.4.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5060d6c8d84851bbb8da630bea59b56414b49923a2b9304fb08f72d4c98f0aeb +size 340256688 diff --git a/roop-unleashed/models/GPEN-BFR-512.onnx b/roop-unleashed/models/GPEN-BFR-512.onnx new file mode 100644 index 0000000000000000000000000000000000000000..13eec093eb1b7133416a5bc9959b1fdc07987ba2 --- /dev/null +++ b/roop-unleashed/models/GPEN-BFR-512.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0960f836488735444d508b588e44fb5dfd19c68fde9163ad7878aa24d1d5115e +size 284250449 diff --git a/roop-unleashed/models/buffalo_l.zip b/roop-unleashed/models/buffalo_l.zip new file mode 100644 index 0000000000000000000000000000000000000000..3c8de83bd164f9ac2abc7eb4486ff01b1fd3af4f --- /dev/null +++ b/roop-unleashed/models/buffalo_l.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80ffe37d8a5940d59a7384c201a2a38d4741f2f3c51eef46ebb28218a7b0ca2f +size 288621354 diff --git a/roop-unleashed/models/buffalo_l/1k3d68.onnx b/roop-unleashed/models/buffalo_l/1k3d68.onnx new file mode 100644 index 0000000000000000000000000000000000000000..221aa2f02a6faccddb2723529e1f93c7db2edbdc --- /dev/null +++ b/roop-unleashed/models/buffalo_l/1k3d68.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +size 143607619 diff --git a/roop-unleashed/models/buffalo_l/2d106det.onnx b/roop-unleashed/models/buffalo_l/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/roop-unleashed/models/buffalo_l/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/roop-unleashed/models/buffalo_l/det_10g.onnx b/roop-unleashed/models/buffalo_l/det_10g.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/roop-unleashed/models/buffalo_l/det_10g.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/roop-unleashed/models/buffalo_l/genderage.onnx b/roop-unleashed/models/buffalo_l/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/roop-unleashed/models/buffalo_l/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/roop-unleashed/models/buffalo_l/w600k_r50.onnx b/roop-unleashed/models/buffalo_l/w600k_r50.onnx new file mode 100644 index 0000000000000000000000000000000000000000..571d2bb9ffd76399b23260620b9101b20bcc4e99 --- /dev/null +++ b/roop-unleashed/models/buffalo_l/w600k_r50.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43 +size 174383860 diff --git a/roop-unleashed/models/inswapper_128.onnx b/roop-unleashed/models/inswapper_128.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cb672b799d74fdf7ab8b172a1b1d78411f6400f5 --- /dev/null +++ b/roop-unleashed/models/inswapper_128.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af +size 554253681 diff --git a/roop-unleashed/models/restoreformer_plus_plus.onnx b/roop-unleashed/models/restoreformer_plus_plus.onnx new file mode 100644 index 0000000000000000000000000000000000000000..54dbba7932a864c0820b036a045ea0774f5370de --- /dev/null +++ b/roop-unleashed/models/restoreformer_plus_plus.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4db5a89902b6a2d452446f5721245a6f7185f699b6aec7b77285adb4d504337 +size 294264812 diff --git a/roop-unleashed/models/xseg.onnx b/roop-unleashed/models/xseg.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6d6d3e341bb1194d5a3da18b776ceec79d455869 --- /dev/null +++ b/roop-unleashed/models/xseg.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b57328efcb839d85973164b617ceee9dfe6cfcb2c82e8a033bba9f4f09b27e5 +size 70327737 diff --git a/roop-unleashed/mypy.ini b/roop-unleashed/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..64218bc23688632a08c98ec4a0451ed46f8ed5e5 --- /dev/null +++ b/roop-unleashed/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +check_untyped_defs = True +disallow_any_generics = True +disallow_untyped_calls = True +disallow_untyped_defs = True +ignore_missing_imports = True +strict_optional = False diff --git a/roop-unleashed/requirements.txt b/roop-unleashed/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..26d424de803dc18fa918ff2501bf6369e6a6eb1e --- /dev/null +++ b/roop-unleashed/requirements.txt @@ -0,0 +1,19 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 + +numpy==1.26.4 +gradio==4.29.0 +opencv-python==4.9.0.80 +onnx==1.16.0 +insightface==0.7.3 +psutil==5.9.6 +torch==2.1.2+cu118; sys_platform != 'darwin' +torch==2.1.2; sys_platform == 'darwin' +torchvision==0.16.2+cu118; sys_platform != 'darwin' +torchvision==0.16.2; sys_platform == 'darwin' +onnxruntime==1.17.1; sys_platform == 'darwin' and platform_machine != 'arm64' +onnxruntime-silicon==1.17.1; sys_platform == 'darwin' and platform_machine == 'arm64' +onnxruntime-gpu==1.17.1; sys_platform != 'darwin' +tqdm==4.66.4 +ftfy +regex +pyvirtualcam diff --git a/roop-unleashed/roop-unleashed.ipynb b/roop-unleashed/roop-unleashed.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5bfa7b67796c0d150d83b4674e804e4491116bb0 --- /dev/null +++ b/roop-unleashed/roop-unleashed.ipynb @@ -0,0 +1,208 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "collapsed_sections": [ + "UdQ1VHdI8lCf" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Colab for roop-unleashed - Gradio version\n", + "https://github.com/C0untFloyd/roop-unleashed\n" + ], + "metadata": { + "id": "G9BdiCppV6AS" + } + }, + { + "cell_type": "markdown", + "source": [ + "Install CUDA V11.8 on Google Cloud Compute" + ], + "metadata": { + "id": "CanIXgLJgaOj" + } + }, + { + "cell_type": "code", + "source": [ + "!apt-get -y update\n", + "!apt-get -y install cuda-toolkit-11-8\n", + "import os\n", + "os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11/lib64\"\n", + "os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11.8/lib64\"" + ], + "metadata": { + "id": "96GE4UgYg3Ej" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Installing & preparing requirements" + ], + "metadata": { + "id": "0ZYRNb0AWLLW" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t1yPuhdySqCq" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/C0untFloyd/roop-unleashed.git\n", + "%cd roop-unleashed\n", + "!mv config_colab.yaml config.yaml\n", + "!pip install pip install -r requirements.txt" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Running roop-unleashed with default config" + ], + "metadata": { + "id": "u_4JQiSlV9Fi" + } + }, + { + "cell_type": "code", + "source": [ + "!python run.py" + ], + "metadata": { + "id": "Is6U2huqSzLE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Download generated images folder\n", + "(only needed if you want to zip the generated output)" + ], + "metadata": { + "id": "UdQ1VHdI8lCf" + } + }, + { + "cell_type": "code", + "source": [ + "import shutil\n", + "import os\n", + "from google.colab import files\n", + "\n", + "def zip_directory(directory_path, zip_path):\n", + " shutil.make_archive(zip_path, 'zip', directory_path)\n", + "\n", + "# Set the directory path you want to download\n", + "directory_path = '/content/roop-unleashed/output'\n", + "\n", + "# Set the zip file name\n", + "zip_filename = 'fake_output.zip'\n", + "\n", + "# Zip the directory\n", + "zip_directory(directory_path, zip_filename)\n", + "\n", + "# Download the zip file\n", + "files.download(zip_filename+'.zip')\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "oYjWveAmw10X", + "outputId": "5b4c3650-f951-434a-c650-5525a8a70c1e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "download(\"download_789eab11-93d2-4880-adf3-6aceee0cc5f9\", \"fake_output.zip.zip\", 80125)" + ] + }, + "metadata": {} + } + ] + } + ] +} diff --git a/roop-unleashed/roop/FaceSet.py b/roop-unleashed/roop/FaceSet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e426219fe4265290883a026fbde2d0513d5d554 --- /dev/null +++ b/roop-unleashed/roop/FaceSet.py @@ -0,0 +1,20 @@ +import numpy as np + +class FaceSet: + faces = [] + ref_images = [] + embedding_average = 'None' + embeddings_backup = None + + def __init__(self): + self.faces = [] + self.ref_images = [] + self.embeddings_backup = None + + def AverageEmbeddings(self): + if len(self.faces) > 1 and self.embeddings_backup is None: + self.embeddings_backup = self.faces[0]['embedding'] + embeddings = [face.embedding for face in self.faces] + + self.faces[0]['embedding'] = np.mean(embeddings, axis=0) + # try median too? diff --git a/roop-unleashed/roop/ProcessEntry.py b/roop-unleashed/roop/ProcessEntry.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd53239463a14769954a10f1371d332bd88e05d --- /dev/null +++ b/roop-unleashed/roop/ProcessEntry.py @@ -0,0 +1,7 @@ +class ProcessEntry: + def __init__(self, filename: str, start: int, end: int, fps: float): + self.filename = filename + self.finalname = None + self.startframe = start + self.endframe = end + self.fps = fps \ No newline at end of file diff --git a/roop-unleashed/roop/ProcessMgr.py b/roop-unleashed/roop/ProcessMgr.py new file mode 100644 index 0000000000000000000000000000000000000000..285089389f1ec71c0b64a18098eb9657cb3fabb1 --- /dev/null +++ b/roop-unleashed/roop/ProcessMgr.py @@ -0,0 +1,701 @@ +import os +import cv2 +import numpy as np +import psutil + +from enum import Enum +from roop.ProcessOptions import ProcessOptions + +from roop.face_util import get_first_face, get_all_faces, rotate_image_180, rotate_anticlockwise, rotate_clockwise, clamp_cut_values +from roop.utilities import compute_cosine_distance, get_device, str_to_class +import roop.vr_util as vr + +from typing import Any, List, Callable +from roop.typing import Frame, Face +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Thread, Lock +from queue import Queue +from tqdm import tqdm +from roop.ffmpeg_writer import FFMPEG_VideoWriter +import roop.globals + +# Poor man's enum to be able to compare to int +class eNoFaceAction(): + USE_ORIGINAL_FRAME = 0 + RETRY_ROTATED = 1 + SKIP_FRAME = 2 + SKIP_FRAME_IF_DISSIMILAR = 3 + + + +def create_queue(temp_frame_paths: List[str]) -> Queue[str]: + queue: Queue[str] = Queue() + for frame_path in temp_frame_paths: + queue.put(frame_path) + return queue + + +def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]: + queues = [] + for _ in range(queue_per_future): + if not queue.empty(): + queues.append(queue.get()) + return queues + + +class ProcessMgr(): + input_face_datas = [] + target_face_datas = [] + + imagemask = None + + processors = [] + options : ProcessOptions = None + + num_threads = 1 + current_index = 0 + processing_threads = 1 + buffer_wait_time = 0.1 + + lock = Lock() + + frames_queue = None + processed_queue = None + + videowriter= None + + progress_gradio = None + total_frames = 0 + + + + + plugins = { + 'faceswap' : 'FaceSwapInsightFace', + 'mask_clip2seg' : 'Mask_Clip2Seg', + 'mask_xseg' : 'Mask_XSeg', + 'codeformer' : 'Enhance_CodeFormer', + 'gfpgan' : 'Enhance_GFPGAN', + 'dmdnet' : 'Enhance_DMDNet', + 'gpen' : 'Enhance_GPEN', + 'restoreformer++' : 'Enhance_RestoreFormerPPlus', + 'colorizer' : 'Frame_Colorizer', + 'filter_generic' : 'Frame_Filter', + 'removebg' : 'Frame_Masking', + 'upscale' : 'Frame_Upscale' + } + + def __init__(self, progress): + if progress is not None: + self.progress_gradio = progress + + def reuseOldProcessor(self, name:str): + for p in self.processors: + if p.processorname == name: + return p + + return None + + + def initialize(self, input_faces, target_faces, options): + self.input_face_datas = input_faces + self.target_face_datas = target_faces + self.options = options + devicename = get_device() + + roop.globals.g_desired_face_analysis=["landmark_3d_68", "landmark_2d_106","detection","recognition"] + if options.swap_mode == "all_female" or options.swap_mode == "all_male": + roop.globals.g_desired_face_analysis.append("genderage") + + for p in self.processors: + newp = next((x for x in options.processors.keys() if x == p.processorname), None) + if newp is None: + p.Release() + del p + + newprocessors = [] + for key, extoption in options.processors.items(): + p = self.reuseOldProcessor(key) + if p is None: + classname = self.plugins[key] + module = 'roop.processors.' + classname + p = str_to_class(module, classname) + if p is not None: + extoption.update({"devicename": devicename}) + p.Initialize(extoption) + newprocessors.append(p) + else: + print(f"Not using {module}") + self.processors = newprocessors + + + + if isinstance(self.options.imagemask, dict) and self.options.imagemask.get("layers") and len(self.options.imagemask["layers"]) > 0: + self.options.imagemask = self.options.imagemask.get("layers")[0] + # Get rid of alpha + self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_RGBA2GRAY) + if np.any(self.options.imagemask): + mo = self.input_face_datas[0].faces[0].mask_offsets + self.options.imagemask = self.blur_area(self.options.imagemask, mo[4], mo[5]) + self.options.imagemask = self.options.imagemask.astype(np.float32) / 255 + self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_GRAY2RGB) + else: + self.options.imagemask = None + + self.options.frame_processing = False + for p in self.processors: + if p.type.startswith("frame_"): + self.options.frame_processing = True + + + + + + + def run_batch(self, source_files, target_files, threads:int = 1): + progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' + self.total_frames = len(source_files) + self.num_threads = threads + with tqdm(total=self.total_frames, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + with ThreadPoolExecutor(max_workers=threads) as executor: + futures = [] + queue = create_queue(source_files) + queue_per_future = max(len(source_files) // threads, 1) + while not queue.empty(): + future = executor.submit(self.process_frames, source_files, target_files, pick_queue(queue, queue_per_future), lambda: self.update_progress(progress)) + futures.append(future) + for future in as_completed(futures): + future.result() + + + def process_frames(self, source_files: List[str], target_files: List[str], current_files, update: Callable[[], None]) -> None: + for f in current_files: + if not roop.globals.processing: + return + + # Decode the byte array into an OpenCV image + temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR) + if temp_frame is not None: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(temp_frame) + resimg = frame + else: + resimg = self.process_frame(temp_frame) + if resimg is not None: + i = source_files.index(f) + cv2.imwrite(target_files[i], resimg) + if update: + update() + + + + def read_frames_thread(self, cap, frame_start, frame_end, num_threads): + num_frame = 0 + total_num = frame_end - frame_start + if frame_start > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES,frame_start) + + while True and roop.globals.processing: + ret, frame = cap.read() + if not ret: + break + + self.frames_queue[num_frame % num_threads].put(frame, block=True) + num_frame += 1 + if num_frame == total_num: + break + + for i in range(num_threads): + self.frames_queue[i].put(None) + + + + def process_videoframes(self, threadindex, progress) -> None: + while True: + frame = self.frames_queue[threadindex].get() + if frame is None: + self.processing_threads -= 1 + self.processed_queue[threadindex].put((False, None)) + return + else: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(frame) + resimg = frame + else: + resimg = self.process_frame(frame) + self.processed_queue[threadindex].put((True, resimg)) + del frame + progress() + + + def write_frames_thread(self): + nextindex = 0 + num_producers = self.num_threads + + while True: + process, frame = self.processed_queue[nextindex % self.num_threads].get() + nextindex += 1 + if frame is not None: + self.videowriter.write_frame(frame) + del frame + elif process == False: + num_producers -= 1 + if num_producers < 1: + return + + + + def run_batch_inmem(self, source_video, target_video, frame_start, frame_end, fps, threads:int = 1, skip_audio=False): + cap = cv2.VideoCapture(source_video) + # frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_count = (frame_end - frame_start) + 1 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + processed_resolution = None + for p in self.processors: + if hasattr(p, 'getProcessedResolution'): + processed_resolution = p.getProcessedResolution(width, height) + print(f"Processed resolution: {processed_resolution}") + if processed_resolution is not None: + width = processed_resolution[0] + height = processed_resolution[1] + + + self.total_frames = frame_count + self.num_threads = threads + + self.processing_threads = self.num_threads + self.frames_queue = [] + self.processed_queue = [] + for _ in range(threads): + self.frames_queue.append(Queue(1)) + self.processed_queue.append(Queue(1)) + + self.videowriter = FFMPEG_VideoWriter(target_video, (width, height), fps, codec=roop.globals.video_encoder, crf=roop.globals.video_quality, audiofile=None) + + readthread = Thread(target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads)) + readthread.start() + + writethread = Thread(target=self.write_frames_thread) + writethread.start() + + progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' + with tqdm(total=self.total_frames, desc='Processing', unit='frames', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + with ThreadPoolExecutor(thread_name_prefix='swap_proc', max_workers=self.num_threads) as executor: + futures = [] + + for threadindex in range(threads): + future = executor.submit(self.process_videoframes, threadindex, lambda: self.update_progress(progress)) + futures.append(future) + + for future in as_completed(futures): + future.result() + # wait for the task to complete + readthread.join() + writethread.join() + cap.release() + self.videowriter.close() + self.frames_queue.clear() + self.processed_queue.clear() + + + + + def update_progress(self, progress: Any = None) -> None: + process = psutil.Process(os.getpid()) + memory_usage = process.memory_info().rss / 1024 / 1024 / 1024 + progress.set_postfix({ + 'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB', + 'execution_threads': self.num_threads + }) + progress.update(1) + if self.progress_gradio is not None: + self.progress_gradio((progress.n, self.total_frames), desc='Processing', total=self.total_frames, unit='frames') + + +# https://github.com/deepinsight/insightface#third-party-re-implementation-of-arcface +# https://github.com/deepinsight/insightface/blob/master/alignment/coordinate_reg/image_infer.py +# https://github.com/deepinsight/insightface/issues/1350 +# https://github.com/linghu8812/tensorrt_inference + + + def process_frame(self, frame:Frame): + if len(self.input_face_datas) < 1 and not self.options.show_face_masking: + return frame + temp_frame = frame.copy() + num_swapped, temp_frame = self.swap_faces(frame, temp_frame) + if num_swapped > 0: + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME_IF_DISSIMILAR: + if len(self.input_face_datas) > num_swapped: + return None + return temp_frame + if roop.globals.no_face_action == eNoFaceAction.USE_ORIGINAL_FRAME: + return frame + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME: + #This only works with in-mem processing, as it simply skips the frame. + #For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg. + #If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is????? + #alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work? + return None + else: + return self.retry_rotated(frame) + + def retry_rotated(self, frame): + copyframe = frame.copy() + copyframe = rotate_clockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_anticlockwise(temp_frame) + + copyframe = frame.copy() + copyframe = rotate_anticlockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_clockwise(temp_frame) + del copyframe + return frame + + + + def swap_faces(self, frame, temp_frame): + num_faces_found = 0 + + if self.options.swap_mode == "first": + face = get_first_face(frame) + + if face is None: + return num_faces_found, frame + + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + else: + faces = get_all_faces(frame) + if faces is None: + return num_faces_found, frame + + if self.options.swap_mode == "all": + for face in faces: + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + del face + + elif self.options.swap_mode == "selected": + num_targetfaces = len(self.target_face_datas) + use_index = num_targetfaces == 1 + for i,tf in enumerate(self.target_face_datas): + for face in faces: + if compute_cosine_distance(tf.embedding, face.embedding) <= self.options.face_distance_threshold: + if i < len(self.input_face_datas): + if use_index: + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + else: + temp_frame = self.process_face(i, face, temp_frame) + num_faces_found += 1 + del face + if not roop.globals.vr_mode and num_faces_found == num_targetfaces: + break + elif self.options.swap_mode == "all_female" or self.options.swap_mode == "all_male": + gender = 'F' if self.options.swap_mode == "all_female" else 'M' + for face in faces: + if face.sex == gender: + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + del face + + if roop.globals.vr_mode and num_faces_found % 2 > 0: + # stereo image, there has to be an even number of faces + num_faces_found = 0 + return num_faces_found, frame + if num_faces_found == 0: + return num_faces_found, frame + + #maskprocessor = next((x for x in self.processors if x.type == 'mask'), None) + + if self.options.imagemask is not None and self.options.imagemask.shape == frame.shape: + temp_frame = self.simple_blend_with_mask(temp_frame, frame, self.options.imagemask) + return num_faces_found, temp_frame + + + def rotation_action(self, original_face:Face, frame:Frame): + (height, width) = frame.shape[:2] + + bounding_box_width = original_face.bbox[2] - original_face.bbox[0] + bounding_box_height = original_face.bbox[3] - original_face.bbox[1] + horizontal_face = bounding_box_width > bounding_box_height + + center_x = width // 2.0 + start_x = original_face.bbox[0] + end_x = original_face.bbox[2] + bbox_center_x = start_x + (bounding_box_width // 2.0) + + # need to leverage the array of landmarks as decribed here: + # https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg + # basically, we should be able to check for the relative position of eyes and nose + # then use that to determine which way the face is actually facing when in a horizontal position + # and use that to determine the correct rotation_action + + forehead_x = original_face.landmark_2d_106[72][0] + chin_x = original_face.landmark_2d_106[0][0] + + if horizontal_face: + if chin_x < forehead_x: + # this is someone lying down with their face like this (: + return "rotate_anticlockwise" + elif forehead_x < chin_x: + # this is someone lying down with their face like this :) + return "rotate_clockwise" + if bbox_center_x >= center_x: + # this is someone lying down with their face in the right hand side of the frame + return "rotate_anticlockwise" + if bbox_center_x < center_x: + # this is someone lying down with their face in the left hand side of the frame + return "rotate_clockwise" + + return None + + + def auto_rotate_frame(self, original_face, frame:Frame): + target_face = original_face + original_frame = frame + + rotation_action = self.rotation_action(original_face, frame) + + if rotation_action == "rotate_anticlockwise": + #face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame + frame = rotate_anticlockwise(frame) + elif rotation_action == "rotate_clockwise": + #face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame + frame = rotate_clockwise(frame) + + return target_face, frame, rotation_action + + + def auto_unrotate_frame(self, frame:Frame, rotation_action): + if rotation_action == "rotate_anticlockwise": + return rotate_clockwise(frame) + elif rotation_action == "rotate_clockwise": + return rotate_anticlockwise(frame) + + return frame + + + + def process_face(self,face_index, target_face:Face, frame:Frame): + from roop.face_util import align_crop + + enhanced_frame = None + if(len(self.input_face_datas) > 0): + inputface = self.input_face_datas[face_index].faces[0] + else: + inputface = None + + rotation_action = None + if roop.globals.autorotate_faces: + # check for sideways rotation of face + rotation_action = self.rotation_action(target_face, frame) + if rotation_action is not None: + (startX, startY, endX, endY) = target_face["bbox"].astype("int") + width = endX - startX + height = endY - startY + offs = int(max(width,height) * 0.25) + rotcutframe,startX, startY, endX, endY = self.cutout(frame, startX - offs, startY - offs, endX + offs, endY + offs) + if rotation_action == "rotate_anticlockwise": + rotcutframe = rotate_anticlockwise(rotcutframe) + elif rotation_action == "rotate_clockwise": + rotcutframe = rotate_clockwise(rotcutframe) + # rotate image and re-detect face to correct wonky landmarks + rotface = get_first_face(rotcutframe) + if rotface is None: + rotation_action = None + else: + saved_frame = frame.copy() + frame = rotcutframe + target_face = rotface + + + + # if roop.globals.vr_mode: + # bbox = target_face.bbox + # [orig_width, orig_height, _] = frame.shape + + # # Convert bounding box to ints + # x1, y1, x2, y2 = map(int, bbox) + + # # Determine the center of the bounding box + # x_center = (x1 + x2) / 2 + # y_center = (y1 + y2) / 2 + + # # Normalize coordinates to range [-1, 1] + # x_center_normalized = x_center / (orig_width / 2) - 1 + # y_center_normalized = y_center / (orig_width / 2) - 1 + + # # Convert normalized coordinates to spherical (theta, phi) + # theta = x_center_normalized * 180 # Theta ranges from -180 to 180 degrees + # phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees + + # img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image + + fake_frame = None + aligned_img, M = align_crop(frame, target_face.kps, 128) + fake_frame = aligned_img + swap_frame = aligned_img + target_face.matrix = M + for p in self.processors: + if p.type == 'swap': + if inputface is not None: + for _ in range(0,self.options.num_swap_steps): + swap_frame = p.Run(inputface, target_face, swap_frame) + fake_frame = swap_frame + scale_factor = 0.0 + elif p.type == 'mask': + fake_frame = self.process_mask(p, aligned_img, fake_frame) + else: + enhanced_frame, scale_factor = p.Run(self.input_face_datas[face_index], target_face, fake_frame) + + upscale = 512 + orig_width = fake_frame.shape[1] + + fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC) + mask_offsets = (0,0,0,0,1,20) if inputface is None else inputface.mask_offsets + + + if enhanced_frame is None: + scale_factor = int(upscale / orig_width) + result = self.paste_upscale(fake_frame, fake_frame, target_face.matrix, frame, scale_factor, mask_offsets) + else: + result = self.paste_upscale(fake_frame, enhanced_frame, target_face.matrix, frame, scale_factor, mask_offsets) + + if rotation_action is not None: + fake_frame = self.auto_unrotate_frame(result, rotation_action) + return self.paste_simple(fake_frame, saved_frame, startX, startY) + + return result + + + + + def cutout(self, frame:Frame, start_x, start_y, end_x, end_y): + if start_x < 0: + start_x = 0 + if start_y < 0: + start_y = 0 + if end_x > frame.shape[1]: + end_x = frame.shape[1] + if end_y > frame.shape[0]: + end_y = frame.shape[0] + return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y + + def paste_simple(self, src:Frame, dest:Frame, start_x, start_y): + end_x = start_x + src.shape[1] + end_y = start_y + src.shape[0] + + start_x, end_x, start_y, end_y = clamp_cut_values(start_x, end_x, start_y, end_y, dest) + dest[start_y:end_y, start_x:end_x] = src + return dest + + def simple_blend_with_mask(self, image1, image2, mask): + # Blend the images + blended_image = image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask + return blended_image.astype(np.uint8) + + + def paste_upscale(self, fake_face, upsk_face, M, target_img, scale_factor, mask_offsets): + M_scale = M * scale_factor + IM = cv2.invertAffineTransform(M_scale) + + face_matte = np.full((target_img.shape[0],target_img.shape[1]), 255, dtype=np.uint8) + # Generate white square sized as a upsk_face + img_matte = np.zeros((upsk_face.shape[0],upsk_face.shape[1]), dtype=np.uint8) + + w = img_matte.shape[1] + h = img_matte.shape[0] + + top = int(mask_offsets[0] * h) + bottom = int(h - (mask_offsets[1] * h)) + left = int(mask_offsets[2] * w) + right = int(w - (mask_offsets[3] * w)) + img_matte[top:bottom,left:right] = 255 + + # Transform white square back to target_img + img_matte = cv2.warpAffine(img_matte, IM, (target_img.shape[1], target_img.shape[0]), flags=cv2.INTER_NEAREST, borderValue=0.0) + ##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges) + img_matte[:1,:] = img_matte[-1:,:] = img_matte[:,:1] = img_matte[:,-1:] = 0 + + img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5]) + #Normalize images to float values and reshape + img_matte = img_matte.astype(np.float32)/255 + face_matte = face_matte.astype(np.float32)/255 + img_matte = np.minimum(face_matte, img_matte) + if self.options.show_face_area_overlay: + # Additional steps for green overlay + green_overlay = np.zeros_like(target_img) + green_color = [0, 255, 0] # RGB for green + for i in range(3): # Apply green color where img_matte is not zero + green_overlay[:, :, i] = np.where(img_matte > 0, green_color[i], 0) ##Transform upcaled face back to target_img + img_matte = np.reshape(img_matte, [img_matte.shape[0],img_matte.shape[1],1]) + paste_face = cv2.warpAffine(upsk_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE) + if upsk_face is not fake_face: + fake_face = cv2.warpAffine(fake_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE) + paste_face = cv2.addWeighted(paste_face, self.options.blend_ratio, fake_face, 1.0 - self.options.blend_ratio, 0) + + # Re-assemble image + paste_face = img_matte * paste_face + paste_face = paste_face + (1-img_matte) * target_img.astype(np.float32) + if self.options.show_face_area_overlay: + # Overlay the green overlay on the final image + paste_face = cv2.addWeighted(paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0) + return paste_face.astype(np.uint8) + + + def blur_area(self, img_matte, num_erosion_iterations, blur_amount): + # Detect the affine transformed white area + mask_h_inds, mask_w_inds = np.where(img_matte==255) + # Calculate the size (and diagonal size) of transformed white area width and height boundaries + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h*mask_w)) + # Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10)) + # k = max(mask_size//12, 8) + k = max(mask_size//(blur_amount // 2) , blur_amount // 2) + kernel = np.ones((k,k),np.uint8) + img_matte = cv2.erode(img_matte,kernel,iterations = num_erosion_iterations) + #Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5)) + # k = max(mask_size//24, 4) + k = max(mask_size//blur_amount, blur_amount//5) + kernel_size = (k, k) + blur_size = tuple(2*i+1 for i in kernel_size) + return cv2.GaussianBlur(img_matte, blur_size, 0) + + + def process_mask(self, processor, frame:Frame, target:Frame): + img_mask = processor.Run(frame, self.options.masking_text) + img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0])) + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + + if self.options.show_face_masking: + result = (1 - img_mask) * frame.astype(np.float32) + return np.uint8(result) + + + target = target.astype(np.float32) + result = (1-img_mask) * target + result += img_mask * frame.astype(np.float32) + return np.uint8(result) + + + + + def unload_models(): + pass + + + def release_resources(self): + for p in self.processors: + p.Release() + self.processors.clear() + diff --git a/roop-unleashed/roop/ProcessOptions.py b/roop-unleashed/roop/ProcessOptions.py new file mode 100644 index 0000000000000000000000000000000000000000..296e8b243796408555a885d11548278ef6ca363c --- /dev/null +++ b/roop-unleashed/roop/ProcessOptions.py @@ -0,0 +1,13 @@ +class ProcessOptions: + + def __init__(self, processordefines:dict, face_distance, blend_ratio, swap_mode, selected_index, masking_text, imagemask, num_steps, show_face_area, show_mask=False): + self.processors = processordefines + self.face_distance_threshold = face_distance + self.blend_ratio = blend_ratio + self.swap_mode = swap_mode + self.selected_index = selected_index + self.masking_text = masking_text + self.imagemask = imagemask + self.num_swap_steps = num_steps + self.show_face_area_overlay = show_face_area + self.show_face_masking = show_mask \ No newline at end of file diff --git a/pretrained_weights/.huggingface/download/denoising_unet.pth.lock b/roop-unleashed/roop/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pretrained_weights/.huggingface/download/denoising_unet.pth.lock rename to roop-unleashed/roop/__init__.py diff --git a/roop-unleashed/roop/__pycache__/FaceSet.cpython-310.pyc b/roop-unleashed/roop/__pycache__/FaceSet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20b0c389973dcb60a1260e27919e55236e01b130 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/FaceSet.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/ProcessEntry.cpython-310.pyc b/roop-unleashed/roop/__pycache__/ProcessEntry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47dc20d22dcce386a06497cd49e4600da54cd364 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/ProcessEntry.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/ProcessMgr.cpython-310.pyc b/roop-unleashed/roop/__pycache__/ProcessMgr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b0187480d30957512149ee617165a11a522b560 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/ProcessMgr.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/ProcessOptions.cpython-310.pyc b/roop-unleashed/roop/__pycache__/ProcessOptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a0a5ff2a01ff19c933d486ade5803e17df1dc1b Binary files /dev/null and b/roop-unleashed/roop/__pycache__/ProcessOptions.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/__init__.cpython-310.pyc b/roop-unleashed/roop/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c819adf07f6fc0e20d6e81b94489ecbc74604609 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/__init__.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/capturer.cpython-310.pyc b/roop-unleashed/roop/__pycache__/capturer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ce285cbbd58bd2468bb615c0f02c01b22f3e66 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/capturer.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/core.cpython-310.pyc b/roop-unleashed/roop/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..236976f3064aa105cabc315fc566febdb0334201 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/core.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/face_util.cpython-310.pyc b/roop-unleashed/roop/__pycache__/face_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcb2830f446baff1378ef151f49212d234c44903 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/face_util.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/ffmpeg_writer.cpython-310.pyc b/roop-unleashed/roop/__pycache__/ffmpeg_writer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d566c96f7a400b9af0750b691745c2daad84570 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/ffmpeg_writer.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/globals.cpython-310.pyc b/roop-unleashed/roop/__pycache__/globals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..960c7ba209f2c183cb4983b0f1acad0257766daf Binary files /dev/null and b/roop-unleashed/roop/__pycache__/globals.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/metadata.cpython-310.pyc b/roop-unleashed/roop/__pycache__/metadata.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22e92f0d3f6c429d658d6da01c52dcfb3fbdd39 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/metadata.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/template_parser.cpython-310.pyc b/roop-unleashed/roop/__pycache__/template_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d3dc40a001cdb50c0ab88d1dd6f188f5315447b Binary files /dev/null and b/roop-unleashed/roop/__pycache__/template_parser.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/typing.cpython-310.pyc b/roop-unleashed/roop/__pycache__/typing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86161bb6f7b76f6f065ef5590d1ac3d1ca0de4be Binary files /dev/null and b/roop-unleashed/roop/__pycache__/typing.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/util_ffmpeg.cpython-310.pyc b/roop-unleashed/roop/__pycache__/util_ffmpeg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1632f2e630b16e99163ccddc73dc885079a115e Binary files /dev/null and b/roop-unleashed/roop/__pycache__/util_ffmpeg.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/utilities.cpython-310.pyc b/roop-unleashed/roop/__pycache__/utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af28a5a995315b143f0e64961af04be17752651 Binary files /dev/null and b/roop-unleashed/roop/__pycache__/utilities.cpython-310.pyc differ diff --git a/roop-unleashed/roop/__pycache__/vr_util.cpython-310.pyc b/roop-unleashed/roop/__pycache__/vr_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc7ad38a56fd93d4166a0888cc4540be471a94b Binary files /dev/null and b/roop-unleashed/roop/__pycache__/vr_util.cpython-310.pyc differ diff --git a/roop-unleashed/roop/capturer.py b/roop-unleashed/roop/capturer.py new file mode 100644 index 0000000000000000000000000000000000000000..6da3ac082fb0b2498e253c05ecae429f81fd1c70 --- /dev/null +++ b/roop-unleashed/roop/capturer.py @@ -0,0 +1,30 @@ +from typing import Optional +import cv2 +import numpy as np + +from roop.typing import Frame + +def get_image_frame(filename: str): + try: + return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR) + except: + print(f"Exception reading {filename}") + return None + + +def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]: + capture = cv2.VideoCapture(video_path) + frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) + capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) + has_frame, frame = capture.read() + capture.release() + if has_frame: + return frame + return None + + +def get_video_frame_total(video_path: str) -> int: + capture = cv2.VideoCapture(video_path) + video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + capture.release() + return video_frame_total diff --git a/roop-unleashed/roop/core.py b/roop-unleashed/roop/core.py new file mode 100755 index 0000000000000000000000000000000000000000..58091faba06fbc5bbdb8fa3a66de909157153927 --- /dev/null +++ b/roop-unleashed/roop/core.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 + +import os +import sys +import shutil +import argparse +import warnings +from typing import List +import platform +import signal +import torch +import onnxruntime +import pathlib +from time import time +import roop.globals +import roop.metadata +import roop.utilities as util +import roop.util_ffmpeg as ffmpeg +from settings import Settings +from roop.face_util import extract_face_images +from roop.ProcessEntry import ProcessEntry +from roop.ProcessMgr import ProcessMgr +from roop.ProcessOptions import ProcessOptions +from roop.capturer import get_video_frame_total +from roop.FaceSet import FaceSet + +process_mgr = None + +if 'ROCMExecutionProvider' in roop.globals.execution_providers: + del torch + +warnings.filterwarnings('ignore', category=FutureWarning, module='insightface') +warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run Roop from the command line") + parser.add_argument('--source_path', type=str, required=True, help="Path to the source file") + parser.add_argument('--target_path', type=str, required=True, help="Path to the target file") + parser.add_argument('--output_path', type=str, required=True, help="Path to save the output file") + parser.add_argument('--execution_provider', type=str, default='CPUExecutionProvider', help="Execution provider for ONNX runtime") + parser.add_argument('--max_memory', type=int, default=None, help="Max memory to use (in GB)") + parser.add_argument('--distance_threshold', type=float, default=0.6, help="Distance threshold for face matching") + parser.add_argument('--blend_ratio', type=float, default=0.5, help="Blend ratio for face swapping") + parser.add_argument('--face_swap_mode', type=str, default='replace', help="Face swap mode") + parser.add_argument('--output_image_format', type=str, default='png', help="Output image format") + parser.add_argument('--output_video_format', type=str, default='mp4', help="Output video format") + parser.add_argument('--execution_threads', type=int, default=8, help="Number of threads to use for execution") + parser.add_argument('--skip_audio', action='store_true', help="Skip audio when processing video") + return parser.parse_args() + + +def encode_execution_providers(execution_providers: List[str]) -> List[str]: + return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] + + +def decode_execution_providers(execution_providers: List[str]) -> List[str]: + return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers())) + if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)] + + +def suggest_max_memory() -> int: + if platform.system().lower() == 'darwin': + return 4 + return 16 + + +def suggest_execution_providers() -> List[str]: + return encode_execution_providers(onnxruntime.get_available_providers()) + + +def suggest_execution_threads() -> int: + if 'DmlExecutionProvider' in roop.globals.execution_providers: + return 1 + if 'ROCMExecutionProvider' in roop.globals.execution_providers: + return 1 + return 8 + + +def limit_resources() -> None: + # limit memory usage + if roop.globals.max_memory: + memory = roop.globals.max_memory * 1024 ** 3 + if platform.system().lower() == 'darwin': + memory = roop.globals.max_memory * 1024 ** 6 + if platform.system().lower() == 'windows': + import ctypes + kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined] + kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)) + else: + import resource + resource.setrlimit(resource.RLIMIT_DATA, (memory, memory)) + + +def release_resources() -> None: + import gc + global process_mgr + + if process_mgr is not None: + process_mgr.release_resources() + process_mgr = None + + gc.collect() + + +def pre_check() -> bool: + if sys.version_info < (3, 9): + update_status('Python version is not supported - please upgrade to 3.9 or higher.') + return False + + download_directory_path = util.resolve_relative_path('../models') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx']) + util.conditional_download(download_directory_path, ['https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/xseg.onnx']) + download_directory_path = util.resolve_relative_path('../models/CLIP') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth']) + download_directory_path = util.resolve_relative_path('../models/CodeFormer') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx']) + download_directory_path = util.resolve_relative_path('../models/Frame') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_artistic.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_stable.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/isnet-general-use.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x4.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x2.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/lsdir_x4.onnx']) + + if not shutil.which('ffmpeg'): + update_status('ffmpeg is not installed.') + return True + + +def update_status(message: str) -> None: + print(message) + + +def get_processing_plugins(masking_engine): + processors = {"faceswap": {}} + if masking_engine is not None: + processors.update({masking_engine: {}}) + + if roop.globals.selected_enhancer == 'GFPGAN': + processors.update({"gfpgan": {}}) + elif roop.globals.selected_enhancer == 'Codeformer': + processors.update({"codeformer": {}}) + elif roop.globals.selected_enhancer == 'DMDNet': + processors.update({"dmdnet": {}}) + elif roop.globals.selected_enhancer == 'GPEN': + processors.update({"gpen": {}}) + elif roop.globals.selected_enhancer == 'Restoreformer++': + processors.update({"restoreformer++": {}}) + return processors + + +def live_swap(frame, options): + global process_mgr + + if frame is None: + return frame + + if process_mgr is None: + process_mgr = ProcessMgr(None) + + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + newframe = process_mgr.process_frame(frame) + if newframe is None: + return frame + return newframe + + +def batch_process_regular(files: List[ProcessEntry], masking_engine: str, new_clip_text: str, use_new_method, imagemask, num_swap_steps, progress, selected_index=0) -> None: + global process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + mask = imagemask["layers"][0] if imagemask is not None else None + if len(roop.globals.INPUT_FACESETS) <= selected_index: + selected_index = 0 + options = ProcessOptions(get_processing_plugins(masking_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, roop.globals.face_swap_mode, selected_index, new_clip_text, mask, num_swap_steps, False) + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + batch_process(files, use_new_method) + return + + +def batch_process_with_options(files: List[ProcessEntry], options, progress): + global process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + roop.globals.keep_frames = False + roop.globals.wait_after_extraction = False + roop.globals.skip_audio = False + batch_process(files, True) + + +def batch_process(files: List[ProcessEntry], use_new_method) -> None: + global process_mgr + + roop.globals.processing = True + + max_threads = suggest_execution_threads() + if max_threads == 1: + roop.globals.execution_threads = 1 + + imagefiles: List[ProcessEntry] = [] + videofiles: List[ProcessEntry] = [] + + update_status('Sorting videos/images') + + for index, f in enumerate(files): + fullname = f.filename + if util.has_image_extension(fullname): + destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'.{roop.globals.CFG.output_image_format}') + destination = util.replace_template(destination, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + f.finalname = destination + imagefiles.append(f) + + elif util.is_video(fullname) or util.has_extension(fullname, ['gif']): + destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'__temp.{roop.globals.CFG.output_video_format}') + f.finalname = destination + videofiles.append(f) + + if len(imagefiles) > 0: + update_status('Processing image(s)') + origimages = [] + fakeimages = [] + for f in imagefiles: + origimages.append(f.filename) + fakeimages.append(f.finalname) + + process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads) + origimages.clear() + fakeimages.clear() + + if len(videofiles) > 0: + for index, v in enumerate(videofiles): + if not roop.globals.processing: + end_processing('Processing stopped!') + return + fps = v.fps if v.fps > 0 else util.detect_fps(v.filename) + if v.endframe == 0: + v.endframe = get_video_frame_total(v.filename) + + update_status(f'Creating {os.path.basename(v.finalname)} with {fps} FPS...') + start_processing = time() + if roop.globals.keep_frames or not use_new_method: + util.create_temp(v.filename) + update_status('Extracting frames...') + ffmpeg.extract_frames(v.filename, v.startframe, v.endframe, fps) + if not roop.globals.processing: + end_processing('Processing stopped!') + return + + temp_frame_paths = util.get_temp_frame_paths(v.filename) + process_mgr.run_batch(temp_frame_paths, temp_frame_paths, roop.globals.execution_threads) + if not roop.globals.processing: + end_processing('Processing stopped!') + return + if roop.globals.wait_after_extraction: + extract_path = os.path.dirname(temp_frame_paths[0]) + util.open_folder(extract_path) + input("Press any key to continue...") + print("Resorting frames to create video") + util.sort_rename_frames(extract_path) + + ffmpeg.create_video(v.filename, v.finalname, fps) + if not roop.globals.keep_frames: + util.delete_temp_frames(temp_frame_paths[0]) + else: + if util.has_extension(v.filename, ['gif']): + skip_audio = True + else: + skip_audio = roop.globals.skip_audio + process_mgr.run_batch_inmem(v.filename, v.finalname, v.startframe, v.endframe, fps, roop.globals.execution_threads, skip_audio) + + if not roop.globals.processing: + end_processing('Processing stopped!') + return + + video_file_name = v.finalname + if os.path.isfile(video_file_name): + destination = '' + if util.has_extension(v.filename, ['gif']): + gifname = util.get_destfilename_from_path(v.filename, roop.globals.output_path, '.gif') + destination = util.replace_template(gifname, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + + update_status('Creating final GIF') + ffmpeg.create_gif_from_video(video_file_name, destination) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + skip_audio = roop.globals.skip_audio + destination = util.replace_template(video_file_name, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + + if not skip_audio: + ffmpeg.restore_audio(video_file_name, v.filename, v.startframe, v.endframe, destination) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + shutil.move(video_file_name, destination) + update_status(f'\nProcessing {os.path.basename(destination)} took {time() - start_processing} secs') + + else: + update_status(f'Failed processing {os.path.basename(v.finalname)}!') + end_processing('Finished') + + +def end_processing(msg: str): + update_status(msg) + roop.globals.target_folder_path = None + release_resources() + + +def destroy() -> None: + if roop.globals.target_path: + util.clean_temp(roop.globals.target_path) + release_resources() + sys.exit() + + +def run() -> None: + args = parse_args() + + roop.globals.source_path = args.source_path + roop.globals.target_path = args.target_path + roop.globals.output_path = args.output_path + roop.globals.execution_providers = decode_execution_providers([args.execution_provider]) + roop.globals.max_memory = args.max_memory + roop.globals.distance_threshold = args.distance_threshold + roop.globals.blend_ratio = args.blend_ratio + roop.globals.face_swap_mode = args.face_swap_mode + roop.globals.CFG = Settings('config.yaml') + roop.globals.execution_threads = args.execution_threads + roop.globals.output_image_format = args.output_image_format + roop.globals.output_video_format = args.output_video_format + roop.globals.skip_audio = args.skip_audio + roop.globals.face_swap_mode == 'selected' + # Ensure these values are set + if not roop.globals.video_encoder: + roop.globals.video_encoder = 'libx264' # or another suitable default value + if not roop.globals.video_quality: + roop.globals.video_quality = 23 # or another suitable default value + + signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) + + if not pre_check(): + return + + # Extract faces from the source and target files and create FaceSet objects + source_faces = extract_face_images(args.source_path, (False, 0)) + target_faces = extract_face_images(args.target_path, (False, util.has_image_extension(args.target_path))) + print("Number of targets faces is ", target_faces.count) + + if source_faces: + source_face_set = FaceSet() + for face_data in source_faces: + face = face_data[0] + face.mask_offsets = (0, 0, 0, 0, 1, 20) + source_face_set.faces.append(face) + if len(source_face_set.faces) > 1: + source_face_set.AverageEmbeddings() + roop.globals.INPUT_FACESETS.append(source_face_set) + + if target_faces: + target_face_set = FaceSet() + for face_data in target_faces: + face = face_data[0] + face.mask_offsets = (0, 0, 0, 0, 1, 20) + target_face_set.faces.append(face) + if len(target_face_set.faces) > 1: + target_face_set.AverageEmbeddings() + roop.globals.TARGET_FACES.append(target_face_set.faces[0]) # Assuming using the first face for target + + # Detect fps and endframe values for the source and target videos + source_fps = util.detect_fps(args.source_path) + source_endframe = get_video_frame_total(args.source_path) + target_fps = util.detect_fps(args.target_path) + target_endframe = get_video_frame_total(args.target_path) + + # Initialize ProcessEntry objects using detected values + source_entry = ProcessEntry( + filename=args.source_path, + start=0, + end=source_endframe, + fps=source_fps + ) + + target_entry = ProcessEntry( + filename=args.target_path, + start=0, + end=target_endframe, + fps=target_fps + ) + + files = [source_entry, target_entry] + batch_process_regular(files, None, None, False, None, 1, None) diff --git a/roop-unleashed/roop/face_util.py b/roop-unleashed/roop/face_util.py new file mode 100644 index 0000000000000000000000000000000000000000..d870632d6d83cf3a007ae065f76a0ded8ea17732 --- /dev/null +++ b/roop-unleashed/roop/face_util.py @@ -0,0 +1,310 @@ +import threading +from typing import Any +import insightface + +import roop.globals +from roop.typing import Frame, Face + +import cv2 +import numpy as np +from skimage import transform as trans +from roop.capturer import get_video_frame +from roop.utilities import resolve_relative_path, conditional_download + +FACE_ANALYSER = None +THREAD_LOCK_ANALYSER = threading.Lock() +THREAD_LOCK_SWAPPER = threading.Lock() +FACE_SWAPPER = None + + +def get_face_analyser() -> Any: + global FACE_ANALYSER + + with THREAD_LOCK_ANALYSER: + if FACE_ANALYSER is None or roop.globals.g_current_face_analysis != roop.globals.g_desired_face_analysis: + model_path = resolve_relative_path('..') + # removed genderage + allowed_modules = roop.globals.g_desired_face_analysis + roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis + if roop.globals.CFG.force_cpu: + print("Forcing CPU for Face Analysis") + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", + root=model_path, providers=["CPUExecutionProvider"],allowed_modules=allowed_modules + ) + else: + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", root=model_path, providers=roop.globals.execution_providers,allowed_modules=allowed_modules + ) + FACE_ANALYSER.prepare( + ctx_id=0, + det_size=(640, 640) if roop.globals.default_det_size else (320, 320), + ) + return FACE_ANALYSER + + +def get_first_face(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return min(faces, key=lambda x: x.bbox[0]) + # return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0] + except: + return None + + +def get_all_faces(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return sorted(faces, key=lambda x: x.bbox[0]) + except: + return None + + +def extract_face_images(source_filename, video_info, extra_padding=-1.0): + face_data = [] + source_image = None + + if video_info[0]: + frame = get_video_frame(source_filename, video_info[1]) + if frame is not None: + source_image = frame + else: + return face_data + else: + source_image = cv2.imdecode(np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR) + + if source_image is None: + print("No source image!") + + faces = get_all_faces(source_image) + if faces is None: + print("NO faces here!") + return face_data + + i = 0 + for face in faces: + (startX, startY, endX, endY) = face["bbox"].astype("int") + startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image) + if extra_padding > 0.0: + if source_image.shape[:2] == (512, 512): + i += 1 + face_data.append([face, source_image]) + continue + + found = False + for i in range(1, 3): + (startX, startY, endX, endY) = face["bbox"].astype("int") + startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image) + cutout_padding = extra_padding + # top needs extra room for detection + padding = int((endY - startY) * cutout_padding) + oldY = startY + startY -= padding + + factor = 0.25 if i == 1 else 0.5 + cutout_padding = factor + padding = int((endY - oldY) * cutout_padding) + endY += padding + padding = int((endX - startX) * cutout_padding) + startX -= padding + endX += padding + startX, endX, startY, endY = clamp_cut_values( + startX, endX, startY, endY, source_image + ) + face_temp = source_image[startY:endY, startX:endX] + face_temp = resize_image_keep_content(face_temp) + testfaces = get_all_faces(face_temp) + if testfaces is not None and len(testfaces) > 0: + i += 1 + face_data.append([testfaces[0], face_temp]) + found = True + break + + if not found: + print("No face found after resizing, this shouldn't happen!") + continue + + face_temp = source_image[startY:endY, startX:endX] + if face_temp.size < 1: + continue + + i += 1 + face_data.append([face, face_temp]) + return face_data + + +def clamp_cut_values(startX, endX, startY, endY, image): + if startX < 0: + startX = 0 + if endX > image.shape[1]: + endX = image.shape[1] + if startY < 0: + startY = 0 + if endY > image.shape[0]: + endY = image.shape[0] + return startX, endX, startY, endY + + + +def face_offset_top(face: Face, offset): + face["bbox"][1] += offset + face["bbox"][3] += offset + lm106 = face.landmark_2d_106 + add = np.full_like(lm106, [0, offset]) + face["landmark_2d_106"] = lm106 + add + return face + + +def resize_image_keep_content(image, new_width=512, new_height=512): + dim = None + (h, w) = image.shape[:2] + if h > w: + r = new_height / float(h) + dim = (int(w * r), new_height) + else: + # Calculate the ratio of the width and construct the dimensions + r = new_width / float(w) + dim = (new_width, int(h * r)) + image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) + (h, w) = image.shape[:2] + if h == new_height and w == new_width: + return image + resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype) + offs = (new_width - w) if h == new_height else (new_height - h) + startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1 + offs = int(offs // 2) + + if h == new_height: + resize_img[0:new_height, startoffs : new_width - offs] = image + else: + resize_img[startoffs : new_height - offs, 0:new_width] = image + return resize_img + + +def rotate_image_90(image, rotate=True): + if rotate: + return np.rot90(image) + else: + return np.rot90(image, 1, (1, 0)) + + +def rotate_anticlockwise(frame): + return rotate_image_90(frame) + + +def rotate_clockwise(frame): + return rotate_image_90(frame, False) + + +def rotate_image_180(image): + return np.flip(image, 0) + + +# alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py + +arcface_dst = np.array( + [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], + ], + dtype=np.float32, +) + + +def estimate_norm(lmk, image_size=112, mode="arcface"): + assert lmk.shape == (5, 2) + assert image_size % 112 == 0 or image_size % 128 == 0 + if image_size % 112 == 0: + ratio = float(image_size) / 112.0 + diff_x = 0 + else: + ratio = float(image_size) / 128.0 + diff_x = 8.0 * ratio + dst = arcface_dst * ratio + dst[:, 0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + + + +# aligned, M = norm_crop2(f[1], face.kps, 512) +def align_crop(img, landmark, image_size=112, mode="arcface"): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped, M + + +def square_crop(im, S): + if im.shape[0] > im.shape[1]: + height = S + width = int(float(im.shape[1]) / im.shape[0] * S) + scale = float(S) / im.shape[0] + else: + width = S + height = int(float(im.shape[0]) / im.shape[1] * S) + scale = float(S) / im.shape[1] + resized_im = cv2.resize(im, (width, height)) + det_im = np.zeros((S, S, 3), dtype=np.uint8) + det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im + return det_im, scale + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + # print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + +def create_blank_image(width, height): + img = np.zeros((height, width, 4), dtype=np.uint8) + img[:] = [0,0,0,0] + return img + diff --git a/roop-unleashed/roop/ffmpeg_writer.py b/roop-unleashed/roop/ffmpeg_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..9642efad2de4e2b3463a62d1ee04b5f02402702c --- /dev/null +++ b/roop-unleashed/roop/ffmpeg_writer.py @@ -0,0 +1,218 @@ +""" +FFMPEG_Writer - write set of frames to video file + +original from +https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py + +removed unnecessary dependencies + +The MIT License (MIT) + +Copyright (c) 2015 Zulko +Copyright (c) 2023 Janvarev Vladislav +""" + +import os +import subprocess as sp + +PIPE = -1 +STDOUT = -2 +DEVNULL = -3 + +FFMPEG_BINARY = "ffmpeg" + +class FFMPEG_VideoWriter: + """ A class for FFMPEG-based video writing. + + A class to write videos using ffmpeg. ffmpeg will write in a large + choice of formats. + + Parameters + ----------- + + filename + Any filename like 'video.mp4' etc. but if you want to avoid + complications it is recommended to use the generic extension + '.avi' for all your videos. + + size + Size (width,height) of the output video in pixels. + + fps + Frames per second in the output video file. + + codec + FFMPEG codec. It seems that in terms of quality the hierarchy is + 'rawvideo' = 'png' > 'mpeg4' > 'libx264' + 'png' manages the same lossless quality as 'rawvideo' but yields + smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list + of accepted codecs. + + Note for default 'libx264': by default the pixel format yuv420p + is used. If the video dimensions are not both even (e.g. 720x405) + another pixel format is used, and this can cause problem in some + video readers. + + audiofile + Optional: The name of an audio file that will be incorporated + to the video. + + preset + Sets the time that FFMPEG will take to compress the video. The slower, + the better the compression rate. Possibilities are: ultrafast,superfast, + veryfast, faster, fast, medium (default), slow, slower, veryslow, + placebo. + + bitrate + Only relevant for codecs which accept a bitrate. "5000k" offers + nice results in general. + + """ + + def __init__(self, filename, size, fps, codec="libx265", crf=14, audiofile=None, + preset="medium", bitrate=None, + logfile=None, threads=None, ffmpeg_params=None): + + if logfile is None: + logfile = sp.PIPE + + self.filename = filename + self.codec = codec + self.ext = self.filename.split(".")[-1] + w = size[0] - 1 if size[0] % 2 != 0 else size[0] + h = size[1] - 1 if size[1] % 2 != 0 else size[1] + + + # order is important + cmd = [ + FFMPEG_BINARY, + '-hide_banner', + '-hwaccel', 'auto', + '-y', + '-loglevel', 'error' if logfile == sp.PIPE else 'info', + '-f', 'rawvideo', + '-vcodec', 'rawvideo', + '-s', '%dx%d' % (size[0], size[1]), + #'-pix_fmt', 'rgba' if withmask else 'rgb24', + '-pix_fmt', 'bgr24', + '-r', str(fps), + '-an', '-i', '-' + ] + + if audiofile is not None: + cmd.extend([ + '-i', audiofile, + '-acodec', 'copy' + ]) + + cmd.extend([ + '-vcodec', codec, + '-crf', str(crf) + #'-preset', preset, + ]) + if ffmpeg_params is not None: + cmd.extend(ffmpeg_params) + if bitrate is not None: + cmd.extend([ + '-b', bitrate + ]) + + # scale to a resolution divisible by 2 if not even + cmd.extend(['-vf', f'scale={w}:{h}' if w != size[0] or h != size[1] else 'colorspace=bt709:iall=bt601-6-625:fast=1']) + + if threads is not None: + cmd.extend(["-threads", str(threads)]) + + cmd.extend([ + '-pix_fmt', 'yuv420p', + + ]) + cmd.extend([ + filename + ]) + + test = str(cmd) + print(test) + + popen_params = {"stdout": DEVNULL, + "stderr": logfile, + "stdin": sp.PIPE} + + # This was added so that no extra unwanted window opens on windows + # when the child process is created + if os.name == "nt": + popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW + + self.proc = sp.Popen(cmd, **popen_params) + + + def write_frame(self, img_array): + """ Writes one frame in the file.""" + try: + #if PY3: + self.proc.stdin.write(img_array.tobytes()) + # else: + # self.proc.stdin.write(img_array.tostring()) + except IOError as err: + _, ffmpeg_error = self.proc.communicate() + error = (str(err) + ("\n\nroop unleashed error: FFMPEG encountered " + "the following error while writing file %s:" + "\n\n %s" % (self.filename, str(ffmpeg_error)))) + + if b"Unknown encoder" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed because FFMPEG didn't find the specified " + "codec for video encoding (%s). Please install " + "this codec or change the codec when calling " + "write_videofile. For instance:\n" + " >>> clip.write_videofile('myvid.webm', codec='libvpx')")%(self.codec) + + elif b"incorrect codec parameters ?" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed, possibly because the codec specified for " + "the video (%s) is not compatible with the given " + "extension (%s). Please specify a valid 'codec' " + "argument in write_videofile. This would be 'libx264' " + "or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. " + "Another possible reason is that the audio codec was not " + "compatible with the video codec. For instance the video " + "extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a" + "video codec." + )%(self.codec, self.ext) + + elif b"encoder setup failed" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed, possibly because the bitrate you specified " + "was too high or too low for the video codec.") + + elif b"Invalid encoder type" in ffmpeg_error: + + error = error + ("\n\nThe video export failed because the codec " + "or file extension you provided is not a video") + + + raise IOError(error) + + def close(self): + if self.proc: + self.proc.stdin.close() + if self.proc.stderr is not None: + self.proc.stderr.close() + self.proc.wait() + + self.proc = None + + # Support the Context Manager protocol, to ensure that resources are cleaned up. + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + + + diff --git a/roop-unleashed/roop/globals.py b/roop-unleashed/roop/globals.py new file mode 100644 index 0000000000000000000000000000000000000000..b1228e3d0652def3c164b332120f8ad2d20292af --- /dev/null +++ b/roop-unleashed/roop/globals.py @@ -0,0 +1,53 @@ +from settings import Settings +from typing import List + +source_path = None +target_path = None +output_path = None +target_folder_path = None + +frame_processors: List[str] = [] +keep_fps = None +keep_frames = None +autorotate_faces = None +vr_mode = None +skip_audio = None +wait_after_extraction = None +many_faces = None +use_batch = None +source_face_index = 0 +target_face_index = 0 +face_position = None +video_encoder = None +video_quality = None +max_memory = None +execution_providers: List[str] = [] +execution_threads = None +headless = None +log_level = 'error' +selected_enhancer = None +face_swap_mode = None +blend_ratio = 0.5 +distance_threshold = 0.65 +default_det_size = True + +no_face_action = 0 + +processing = False + +g_current_face_analysis = None +g_desired_face_analysis = None + +FACE_ENHANCER = None + +INPUT_FACESETS = [] +TARGET_FACES = [] + + +IMAGE_CHAIN_PROCESSOR = None +VIDEO_CHAIN_PROCESSOR = None +BATCH_IMAGE_CHAIN_PROCESSOR = None + +CFG: Settings = None + + diff --git a/roop-unleashed/roop/metadata.py b/roop-unleashed/roop/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..469e3990c42b6a278b1d7941bdc4dac53f28c72e --- /dev/null +++ b/roop-unleashed/roop/metadata.py @@ -0,0 +1,2 @@ +name = 'roop unleashed' +version = '4.0.0' diff --git a/roop-unleashed/roop/processors/Enhance_CodeFormer.py b/roop-unleashed/roop/processors/Enhance_CodeFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d00a3d431f6b16a659d5722314b3531a6af425d --- /dev/null +++ b/roop-unleashed/roop/processors/Enhance_CodeFormer.py @@ -0,0 +1,75 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +# THREAD_LOCK = threading.Lock() + + +class Enhance_CodeFormer(): + model_codeformer = None + + plugin_options:dict = None + + processorname = 'codeformer' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_codeformer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx') + self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_codeformer.get_inputs() + model_outputs = self.model_codeformer.get_outputs() + self.io_binding = self.model_codeformer.io_binding() + self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5])) + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + input_size = temp_frame.shape[1] + # preprocess + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32)) + self.model_codeformer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + + # post-process + result = result.transpose((1, 2, 0)) + + un_min = -1.0 + un_max = 1.0 + result = np.clip(result, un_min, un_max) + result = (result - un_min) / (un_max - un_min) + + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + result = (result * 255.0).round() + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + del self.model_codeformer + self.model_codeformer = None + del self.io_binding + self.io_binding = None + diff --git a/roop-unleashed/roop/processors/Enhance_DMDNet.py b/roop-unleashed/roop/processors/Enhance_DMDNet.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6a6bb2d2fdad863dcbf66da8e498555d357a64 --- /dev/null +++ b/roop-unleashed/roop/processors/Enhance_DMDNet.py @@ -0,0 +1,898 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as SpectralNorm +import threading +from torchvision.ops import roi_align + +from math import sqrt + +from torchvision.transforms.functional import normalize + +from roop.typing import Face, Frame, FaceSet + + +THREAD_LOCK_DMDNET = threading.Lock() + + +class Enhance_DMDNet(): + plugin_options:dict = None + model_dmdnet = None + torchdevice = None + + processorname = 'dmdnet' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_dmdnet is None: + self.model_dmdnet = self.create(self.plugin_options["devicename"]) + + + # temp_frame already cropped+aligned, bbox not + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + input_size = temp_frame.shape[1] + + result = self.enhance_face(source_faceset, temp_frame, target_face) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gfpgan = None + + + # https://stackoverflow.com/a/67174339 + def landmarks106_to_68(self, pt106): + map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17, + 43,48,49,51,50, + 102,103,104,105,101, + 72,73,74,86,78,79,80,85,84, + 35,41,42,39,37,36, + 89,95,96,93,91,90, + 52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54 + ] + + pt68 = [] + for i in range(68): + index = map106to68[i] + pt68.append(pt106[index]) + return pt68 + + + + + def check_bbox(self, imgs, boxes): + boxes = boxes.view(-1, 4, 4) + colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)] + i = 0 + for img, box in zip(imgs, boxes): + img = (img + 1)/2 * 255 + img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy() + for idx, point in enumerate(box): + cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2) + cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2) + i += 1 + + + def trans_points2d(self, pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + + def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face): + # preprocess + start_x, start_y, end_x, end_y = map(int, face['bbox']) + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + + if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / temp_frame.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + temp_frame = cv2.resize(temp_frame, (512,512), interpolation = cv2.INTER_AREA) + + if temp_frame.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + lq = read_img_tensor(temp_frame) + + LQLocs = get_component_location(lq_landmarks) + # self.check_bbox(lq, LQLocs.unsqueeze(0)) + + # specific, change 1000 to 1 to activate + if len(ref_faceset.faces) > 1: + SpecificImgs = [] + SpecificLocs = [] + for i,face in enumerate(ref_faceset.faces): + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + ref_image = ref_faceset.ref_images[i] + if ref_image.shape[0] != 512 or ref_image.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / ref_image.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + ref_image = cv2.resize(ref_image, (512,512), interpolation = cv2.INTER_AREA) + + if ref_image.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + ref_tensor = read_img_tensor(ref_image) + ref_locs = get_component_location(lq_landmarks) + # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0)) + + SpecificImgs.append(ref_tensor) + SpecificLocs.append(ref_locs.unsqueeze(0)) + + SpecificImgs = torch.cat(SpecificImgs, dim=0) + SpecificLocs = torch.cat(SpecificLocs, dim=0) + # check_bbox(SpecificImgs, SpecificLocs) + SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs) + SpMem256Para = {} + SpMem128Para = {} + SpMem64Para = {} + for k, v in SpMem256.items(): + SpMem256Para[k] = v + for k, v in SpMem128.items(): + SpMem128Para[k] = v + for k, v in SpMem64.items(): + SpMem64Para[k] = v + else: + # generic + SpMem256Para, SpMem128Para, SpMem64Para = None, None, None + + with torch.no_grad(): + with THREAD_LOCK_DMDNET: + try: + GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para) + except Exception as e: + print(f'Error {e} there may be something wrong with the detected component locations.') + return temp_frame + + if SpecificResult is not None: + save_specific = SpecificResult * 0.5 + 0.5 + save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_specific.astype("uint8") + if False: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + check_lq = lq * 0.5 + 0.5 + check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0 + cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)),cv2.COLOR_RGB2BGR)) + else: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_generic.astype("uint8") + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB + return temp_frame + + + + def create(self, devicename): + self.torchdevice = torch.device(devicename) + model_dmdnet = DMDNet().to(self.torchdevice) + weights = torch.load('./models/DMDNet.pth') + model_dmdnet.load_state_dict(weights, strict=True) + + model_dmdnet.eval() + num_params = 0 + for param in model_dmdnet.parameters(): + num_params += param.numel() + return model_dmdnet + + # print('{:>8s} : {}'.format('Using device', device)) + # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6)) + + + +def read_img_tensor(Img=None): #rgb -1~1 + Img = Img.transpose((2, 0, 1))/255.0 + Img = torch.from_numpy(Img).float() + normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True) + ImgTensor = Img.unsqueeze(0) + return ImgTensor + + +def get_component_location(Landmarks, re_read=False): + if re_read: + ReadLandmark = [] + with open(Landmarks,'r') as f: + for line in f: + tmp = [float(i) for i in line.split(' ') if i != '\n'] + ReadLandmark.append(tmp) + ReadLandmark = np.array(ReadLandmark) # + Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2 + Map_LE_B = list(np.hstack((range(17,22), range(36,42)))) + Map_RE_B = list(np.hstack((range(22,27), range(42,48)))) + Map_LE = list(range(36,42)) + Map_RE = list(range(42,48)) + Map_NO = list(range(29,36)) + Map_MO = list(range(48,68)) + + Landmarks[Landmarks>504]=504 + Landmarks[Landmarks<8]=8 + + #left eye + Mean_LE = np.mean(Landmarks[Map_LE],0) + L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1]) + L_LE1 = L_LE1 * 1.3 + L_LE2 = L_LE1 / 1.9 + L_LE_xy = L_LE1 + L_LE2 + L_LE_lt = [L_LE_xy/2, L_LE1] + L_LE_rb = [L_LE_xy/2, L_LE2] + Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int) + + #right eye + Mean_RE = np.mean(Landmarks[Map_RE],0) + L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1]) + L_RE1 = L_RE1 * 1.3 + L_RE2 = L_RE1 / 1.9 + L_RE_xy = L_RE1 + L_RE2 + L_RE_lt = [L_RE_xy/2, L_RE1] + L_RE_rb = [L_RE_xy/2, L_RE2] + Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int) + + #nose + Mean_NO = np.mean(Landmarks[Map_NO],0) + L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25 + L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1 + L_NO_xy = L_NO1 * 2 + L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2] + L_NO_rb = [L_NO_xy/2, L_NO2] + Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int) + + #mouth + Mean_MO = np.mean(Landmarks[Map_MO],0) + L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1 + MO_O = Mean_MO - L_MO + 1 + MO_T = Mean_MO + L_MO + MO_T[MO_T>510]=510 + Location_MO = np.hstack((MO_O, MO_T)).astype(int) + return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0) + + + + +def calc_mean_std_4D(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature + size = content_feat.size() + style_mean, style_std = calc_mean_std_4D(style_feat) + content_mean, content_std = calc_mean_std_4D(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True): + return nn.Sequential( + SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), + nn.LeakyReLU(0.2), + SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), + ) + + +class MSDilateBlock(nn.Module): + def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True): + super(MSDilateBlock, self).__init__() + self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias) + self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias) + self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias) + self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias) + self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)) + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(x) + conv3 = self.conv3(x) + conv4 = self.conv4(x) + cat = torch.cat([conv1, conv2, conv3, conv4], 1) + out = self.convi(cat) + x + return out + + +class AdaptiveInstanceNorm(nn.Module): + def __init__(self, in_channel): + super().__init__() + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + style_mean, style_std = calc_mean_std_4D(style) + out = self.norm(input) + size = input.size() + out = style_std.expand(size) * out + style_mean.expand(size) + return out + +class NoiseInjection(nn.Module): + def __init__(self, channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + def forward(self, image, noise): + if noise is None: + b, c, h, w = image.shape + noise = image.new_empty(b, 1, h, w).normal_() + return image + self.weight * noise + +class StyledUpBlock(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False): + super().__init__() + + self.noise_inject = noise_inject + if upsample: + self.conv1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + ) + else: + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + ) + if self.noise_inject: + self.noise1 = NoiseInjection(out_channel) + + self.lrelu1 = nn.LeakyReLU(0.2) + + self.ScaleModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)) + ) + self.ShiftModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), + ) + + def forward(self, input, style): + out = self.conv1(input) + out = self.lrelu1(out) + Shift1 = self.ShiftModel1(style) + Scale1 = self.ScaleModel1(style) + out = out * Scale1 + Shift1 + if self.noise_inject: + out = self.noise1(out, noise=None) + outup = self.convup(out) + return outup + + +#################################################################### +###############Face Dictionary Generator +#################################################################### +def AttentionBlock(in_channel): + return nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + ) + +class DilateResBlock(nn.Module): + def __init__(self, dim, dilation=[5,3] ): + super(DilateResBlock, self).__init__() + self.Res = nn.Sequential( + SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])), + ) + def forward(self, x): + out = x + self.Res(x) + return out + + +class KeyValue(nn.Module): + def __init__(self, indim, keydim, valdim): + super(KeyValue, self).__init__() + self.Key = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.Value = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x): + return self.Key(x), self.Value(x) + +class MaskAttention(nn.Module): + def __init__(self, indim): + super(MaskAttention, self).__init__() + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.conv2 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.conv3 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.convCat = nn.Sequential( + SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x, y, z): + c1 = self.conv1(x) + c2 = self.conv2(y) + c3 = self.conv3(z) + return self.convCat(torch.cat([c1,c2,c3], dim=1)) + +class Query(nn.Module): + def __init__(self, indim, quedim): + super(Query, self).__init__() + self.Query = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x): + return self.Query(x) + +def roi_align_self(input, location, target_size): + test = (target_size.item(),target_size.item()) + return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],test,mode='bilinear',align_corners=False) for i in range(input.size(0))],0) + +class FeatureExtractor(nn.Module): + def __init__(self, ngf = 64, key_scale = 4):# + super().__init__() + + self.key_scale = 4 + self.part_sizes = np.array([80,80,50,110]) # + self.feature_sizes = np.array([256,128,64]) # + + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + ) + self.conv2 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)) + ) + self.res1 = DilateResBlock(ngf, [5,3]) + self.res2 = DilateResBlock(ngf, [5,3]) + + + self.conv3 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)), + ) + self.conv4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)) + ) + self.res3 = DilateResBlock(ngf*2, [3,1]) + self.res4 = DilateResBlock(ngf*2, [3,1]) + + self.conv5 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)), + ) + self.conv6 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)) + ) + self.res5 = DilateResBlock(ngf*4, [1,1]) + self.res6 = DilateResBlock(ngf*4, [1,1]) + + self.LE_256_Q = Query(ngf, ngf // self.key_scale) + self.RE_256_Q = Query(ngf, ngf // self.key_scale) + self.MO_256_Q = Query(ngf, ngf // self.key_scale) + self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + + + def forward(self, img, locs): + le_location = locs[:,0,:].int().cpu().numpy() + re_location = locs[:,1,:].int().cpu().numpy() + no_location = locs[:,2,:].int().cpu().numpy() + mo_location = locs[:,3,:].int().cpu().numpy() + + + f1_0 = self.conv1(img) + f1_1 = self.res1(f1_0) + f2_0 = self.conv2(f1_1) + f2_1 = self.res2(f2_0) + + f3_0 = self.conv3(f2_1) + f3_1 = self.res3(f3_0) + f4_0 = self.conv4(f3_1) + f4_1 = self.res4(f4_0) + + f5_0 = self.conv5(f4_1) + f5_1 = self.res5(f5_0) + f6_0 = self.conv6(f5_1) + f6_1 = self.res6(f6_0) + + + ####ROI Align + le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2) + re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2) + mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2) + + le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4) + re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4) + mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4) + + le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8) + re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8) + mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8) + + + le_256_q = self.LE_256_Q(le_part_256) + re_256_q = self.RE_256_Q(re_part_256) + mo_256_q = self.MO_256_Q(mo_part_256) + + le_128_q = self.LE_128_Q(le_part_128) + re_128_q = self.RE_128_Q(re_part_128) + mo_128_q = self.MO_128_Q(mo_part_128) + + le_64_q = self.LE_64_Q(le_part_64) + re_64_q = self.RE_64_Q(re_part_64) + mo_64_q = self.MO_64_Q(mo_part_64) + + return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\ + 'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \ + 'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \ + 'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \ + 'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\ + 'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\ + 'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q} + + +class DMDNet(nn.Module): + def __init__(self, ngf = 64, banks_num = 128): + super().__init__() + self.part_sizes = np.array([80,80,50,110]) # size for 512 + self.feature_sizes = np.array([256,128,64]) # size for 512 + + self.banks_num = banks_num + self.key_scale = 4 + + self.E_lq = FeatureExtractor(key_scale = self.key_scale) + self.E_hq = FeatureExtractor(key_scale = self.key_scale) + + self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + + self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + + self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + + + self.LE_256_Attention = AttentionBlock(64) + self.RE_256_Attention = AttentionBlock(64) + self.MO_256_Attention = AttentionBlock(64) + + self.LE_128_Attention = AttentionBlock(128) + self.RE_128_Attention = AttentionBlock(128) + self.MO_128_Attention = AttentionBlock(128) + + self.LE_64_Attention = AttentionBlock(256) + self.RE_64_Attention = AttentionBlock(256) + self.MO_64_Attention = AttentionBlock(256) + + self.LE_256_Mask = MaskAttention(64) + self.RE_256_Mask = MaskAttention(64) + self.MO_256_Mask = MaskAttention(64) + + self.LE_128_Mask = MaskAttention(128) + self.RE_128_Mask = MaskAttention(128) + self.MO_128_Mask = MaskAttention(128) + + self.LE_64_Mask = MaskAttention(256) + self.RE_64_Mask = MaskAttention(256) + self.MO_64_Mask = MaskAttention(256) + + self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1]) + + self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) # + self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) # + self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) # + self.up4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + UpResBlock(ngf), + UpResBlock(ngf), + SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)), + nn.Tanh() + ) + + # define generic memory, revise register_buffer to register_parameter for backward update + self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40)) + self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40)) + self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55)) + self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40)) + self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40)) + self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55)) + + + self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20)) + self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20)) + self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27)) + self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20)) + self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20)) + self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27)) + + self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10)) + self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10)) + self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13)) + self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10)) + self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10)) + self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13)) + + + def readMem(self, k, v, q): + sim = F.conv2d(q, k) + score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128 + sb,sn,sw,sh = score.size() + s_m = score.view(sb, -1).unsqueeze(1)#2*1*M + vb,vn,vw,vh = v.size() + v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h) + mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh) + max_inds = torch.argmax(score, dim=1).squeeze() + return mem_out, max_inds + + + def memorize(self, img, locs): + fs = self.E_hq(img, locs) + LE256_key, LE256_value = self.LE_256_KV(fs['le256']) + RE256_key, RE256_value = self.RE_256_KV(fs['re256']) + MO256_key, MO256_value = self.MO_256_KV(fs['mo256']) + + LE128_key, LE128_value = self.LE_128_KV(fs['le128']) + RE128_key, RE128_value = self.RE_128_KV(fs['re128']) + MO128_key, MO128_value = self.MO_128_KV(fs['mo128']) + + LE64_key, LE64_value = self.LE_64_KV(fs['le64']) + RE64_key, RE64_value = self.RE_64_KV(fs['re64']) + MO64_key, MO64_value = self.MO_64_KV(fs['mo64']) + + Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value} + Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value} + Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value} + + FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']} + FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']} + FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']} + + return Mem256, Mem128, Mem64 + + def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None): + le_256_q = fs_in['le_256_q'] + re_256_q = fs_in['re_256_q'] + mo_256_q = fs_in['mo_256_q'] + + le_128_q = fs_in['le_128_q'] + re_128_q = fs_in['re_128_q'] + mo_128_q = fs_in['mo_128_q'] + + le_64_q = fs_in['le_64_q'] + re_64_q = fs_in['re_64_q'] + mo_64_q = fs_in['mo_64_q'] + + + ####for 256 + le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q) + re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q) + mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q) + + le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q) + re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q) + mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q) + + le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q) + re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q) + mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q) + + if sp_256 is not None and sp_128 is not None and sp_64 is not None: + le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q) + re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q) + mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q) + le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g) + le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g + re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g) + re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g + mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g) + mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g + + le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q) + re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q) + mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q) + le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g) + le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g + re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g) + re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g + mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g) + mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g + + le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q) + re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q) + mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q) + le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g) + le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g + re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g) + re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g + mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g) + mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g + else: + le_256_mem = le_256_mem_g + re_256_mem = re_256_mem_g + mo_256_mem = mo_256_mem_g + le_128_mem = le_128_mem_g + re_128_mem = re_128_mem_g + mo_128_mem = mo_128_mem_g + le_64_mem = le_64_mem_g + re_64_mem = re_64_mem_g + mo_64_mem = mo_64_mem_g + + le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256']) + re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256']) + mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256']) + + ####for 128 + le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128']) + re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128']) + mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128']) + + ####for 64 + le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64']) + re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64']) + mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64']) + + + EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm} + EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm} + EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm} + Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds} + Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds} + Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds} + return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64 + + def reconstruct(self, fs_in, locs, memstar): + le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm'] + le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm'] + le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm'] + + le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256'] + re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256'] + mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256'] + + le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128'] + re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128'] + mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128'] + + le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64'] + re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64'] + mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64'] + + + le_location = locs[:,0,:] + re_location = locs[:,1,:] + mo_location = locs[:,3,:] + + # Somehow with latest Torch it doesn't like numpy wrappers anymore + + # le_location = le_location.cpu().int().numpy() + # re_location = re_location.cpu().int().numpy() + # mo_location = mo_location.cpu().int().numpy() + le_location = le_location.cpu().int() + re_location = re_location.cpu().int() + mo_location = mo_location.cpu().int() + + up_in_256 = fs_in['f256'].clone()# * 0 + up_in_128 = fs_in['f128'].clone()# * 0 + up_in_64 = fs_in['f64'].clone()# * 0 + + for i in range(fs_in['f256'].size(0)): + up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False) + up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False) + up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False) + + up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False) + up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False) + up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False) + + up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False) + up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False) + up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False) + + ms_in_64 = self.MSDilate(fs_in['f64'].clone()) + fea_up1 = self.up1(ms_in_64, up_in_64) + fea_up2 = self.up2(fea_up1, up_in_128) # + fea_up3 = self.up3(fea_up2, up_in_256) # + output = self.up4(fea_up3) # + return output + + def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None): + return self.memorize(sp_imgs, sp_locs) + + def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None): + try: + fs_in = self.E_lq(lq, loc) # low quality images + except Exception as e: + print(e) + + GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in) + GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64]) + if sp_256 is not None and sp_128 is not None and sp_64 is not None: + GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64) + GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64]) + else: + GSOut = None + return GeOut, GSOut + +class UpResBlock(nn.Module): + def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d): + super(UpResBlock, self).__init__() + self.Model = nn.Sequential( + SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + ) + def forward(self, x): + out = x + self.Model(x) + return out diff --git a/roop-unleashed/roop/processors/Enhance_GFPGAN.py b/roop-unleashed/roop/processors/Enhance_GFPGAN.py new file mode 100644 index 0000000000000000000000000000000000000000..ca61cb70f302712ca1c6f54ee06aad9ed0f33f0c --- /dev/null +++ b/roop-unleashed/roop/processors/Enhance_GFPGAN.py @@ -0,0 +1,77 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +# THREAD_LOCK = threading.Lock() + + +class Enhance_GFPGAN(): + plugin_options:dict = None + + model_gfpgan = None + name = None + devicename = None + + processorname = 'gfpgan' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_gfpgan is None: + model_path = resolve_relative_path('../models/GFPGANv1.4.onnx') + self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + self.name = self.model_gfpgan.get_inputs()[0].name + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + io_binding = self.model_gfpgan.io_binding() + io_binding.bind_cpu_input("input", temp_frame) + io_binding.bind_output("1288", self.devicename) + self.model_gfpgan.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + + # post-process + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gfpgan = None + + + + + + + + + + + diff --git a/roop-unleashed/roop/processors/Enhance_GPEN.py b/roop-unleashed/roop/processors/Enhance_GPEN.py new file mode 100644 index 0000000000000000000000000000000000000000..9821e70534e3bddcd2a932548fd7b9250d85a41a --- /dev/null +++ b/roop-unleashed/roop/processors/Enhance_GPEN.py @@ -0,0 +1,63 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +class Enhance_GPEN(): + plugin_options:dict = None + + model_gpen = None + name = None + devicename = None + + processorname = 'gpen' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_gpen is None: + model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx') + self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + self.name = self.model_gpen.get_inputs()[0].name + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + io_binding = self.model_gpen.io_binding() + io_binding.bind_cpu_input("input", temp_frame) + io_binding.bind_output("output", self.devicename) + self.model_gpen.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + + # post-process + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gpen = None diff --git a/roop-unleashed/roop/processors/Enhance_RestoreFormerPPlus.py b/roop-unleashed/roop/processors/Enhance_RestoreFormerPPlus.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d71034573cf1e63be77a4b9acafc854f189536 --- /dev/null +++ b/roop-unleashed/roop/processors/Enhance_RestoreFormerPPlus.py @@ -0,0 +1,64 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + +class Enhance_RestoreFormerPPlus(): + plugin_options:dict = None + model_restoreformerpplus = None + devicename = None + name = None + + processorname = 'restoreformer++' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_restoreformerpplus is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx') + self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_restoreformerpplus.get_inputs() + model_outputs = self.model_restoreformerpplus.get_outputs() + self.io_binding = self.model_restoreformerpplus.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) # .astype(np.float32) + self.model_restoreformerpplus.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + del self.model_restoreformerpplus + self.model_restoreformerpplus = None + del self.io_binding + self.io_binding = None + diff --git a/roop-unleashed/roop/processors/FaceSwapInsightFace.py b/roop-unleashed/roop/processors/FaceSwapInsightFace.py new file mode 100644 index 0000000000000000000000000000000000000000..34290899fed8f74b4e7bc7aaf2909779dfb4d639 --- /dev/null +++ b/roop-unleashed/roop/processors/FaceSwapInsightFace.py @@ -0,0 +1,69 @@ +import roop.globals +import cv2 +import numpy as np +import onnx +import onnxruntime + +from roop.typing import Face, Frame +from roop.utilities import resolve_relative_path + + + +class FaceSwapInsightFace(): + plugin_options:dict = None + model_swap_insightface = None + + processorname = 'faceswap' + type = 'swap' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_swap_insightface is None: + model_path = resolve_relative_path('../models/inswapper_128.onnx') + graph = onnx.load(model_path).graph + self.emap = onnx.numpy_helper.to_array(graph.initializer[-1]) + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + self.input_mean = 0.0 + self.input_std = 255.0 + #cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'} + sess_options = onnxruntime.SessionOptions() + sess_options.enable_cpu_mem_arena = False + self.model_swap_insightface = onnxruntime.InferenceSession(model_path, sess_options, providers=roop.globals.execution_providers) + + + + def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: + blob = cv2.dnn.blobFromImage(temp_frame, 1.0 / self.input_std, (128, 128), + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + latent = source_face.normed_embedding.reshape((1,-1)) + latent = np.dot(latent, self.emap) + latent /= np.linalg.norm(latent) + io_binding = self.model_swap_insightface.io_binding() + io_binding.bind_cpu_input("target", blob) + io_binding.bind_cpu_input("source", latent) + io_binding.bind_output("output", self.devicename) + self.model_swap_insightface.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu()[0] + img_fake = ort_outs.transpose((0,2,3,1))[0] + return np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] + + + img_fake, M = self.model_swap_insightface.get(temp_frame, target_face, source_face, paste_back=False) + # target_face.matrix = M + # return img_fake + + + def Release(self): + del self.model_swap_insightface + self.model_swap_insightface = None + + + + + + diff --git a/roop-unleashed/roop/processors/Frame_Colorizer.py b/roop-unleashed/roop/processors/Frame_Colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..372f81870b6c47f543707e8eefff3a474532b493 --- /dev/null +++ b/roop-unleashed/roop/processors/Frame_Colorizer.py @@ -0,0 +1,70 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Colorizer(): + plugin_options:dict = None + model_colorizer = None + devicename = None + prev_type = None + + processorname = 'deoldify' + type = 'frame_colorizer' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]: + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_colorizer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + if self.prev_type == "deoldify_artistic": + model_path = resolve_relative_path('../models/Frame/deoldify_artistic.onnx') + elif self.prev_type == "deoldify_stable": + model_path = resolve_relative_path('../models/Frame/deoldify_stable.onnx') + + onnxruntime.set_default_logger_severity(3) + self.model_colorizer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_colorizer.get_inputs() + model_outputs = self.model_colorizer.get_outputs() + self.io_binding = self.model_colorizer.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, input_frame: Frame) -> Frame: + temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) + temp_frame = cv2.resize(temp_frame, (256, 256)) + temp_frame = temp_frame.transpose((2, 0, 1)) + temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32) + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) + self.model_colorizer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + colorized_frame = result.transpose(1, 2, 0) + colorized_frame = cv2.resize(colorized_frame, (input_frame.shape[1], input_frame.shape[0])) + temp_blue_channel, _, _ = cv2.split(input_frame) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(np.uint8) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB) + _, color_green_channel, color_red_channel = cv2.split(colorized_frame) + colorized_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel)) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR) + return colorized_frame.astype(np.uint8) + + + def Release(self): + del self.model_colorizer + self.model_colorizer = None + del self.io_binding + self.io_binding = None + diff --git a/roop-unleashed/roop/processors/Frame_Filter.py b/roop-unleashed/roop/processors/Frame_Filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b1405c329167a4e7f4f926ade5cf06ab6166466f --- /dev/null +++ b/roop-unleashed/roop/processors/Frame_Filter.py @@ -0,0 +1,105 @@ +import cv2 +import numpy as np + +from roop.typing import Frame + +class Frame_Filter(): + processorname = 'generic_filter' + type = 'frame_processor' + + plugin_options:dict = None + + c64_palette = np.array([ + [0, 0, 0], + [255, 255, 255], + [0x81, 0x33, 0x38], + [0x75, 0xce, 0xc8], + [0x8e, 0x3c, 0x97], + [0x56, 0xac, 0x4d], + [0x2e, 0x2c, 0x9b], + [0xed, 0xf1, 0x71], + [0x8e, 0x50, 0x29], + [0x55, 0x38, 0x00], + [0xc4, 0x6c, 0x71], + [0x4a, 0x4a, 0x4a], + [0x7b, 0x7b, 0x7b], + [0xa9, 0xff, 0x9f], + [0x70, 0x6d, 0xeb], + [0xb2, 0xb2, 0xb2] + ]) + + + def RenderC64Screen(self, image): + # Simply round the color values to the nearest color in the palette + image = cv2.resize(image,(320,200)) + palette = self.c64_palette / 255.0 # Normalize palette + img_normalized = image / 255.0 # Normalize image + + # Calculate the index in the palette that is closest to each pixel in the image + indices = np.sqrt(((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3)).argmin(axis=2) + # Map the image to the palette colors + mapped_image = palette[indices] + return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image + + + def RenderDetailEnhance(self, image): + return cv2.detailEnhance(image) + + def RenderStylize(self, image): + return cv2.stylization(image) + + def RenderPencilSketch(self, image): + imgray, imout = cv2.pencilSketch(image, sigma_s=60, sigma_r=0.07, shade_factor=0.05) + return imout + + def RenderCartoon(self, image): + numDownSamples = 2 # number of downscaling steps + numBilateralFilters = 7 # number of bilateral filtering steps + + img_color = image + for _ in range(numDownSamples): + img_color = cv2.pyrDown(img_color) + for _ in range(numBilateralFilters): + img_color = cv2.bilateralFilter(img_color, 9, 9, 7) + for _ in range(numDownSamples): + img_color = cv2.pyrUp(img_color) + img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + img_blur = cv2.medianBlur(img_gray, 7) + img_edge = cv2.adaptiveThreshold(img_blur, 255, + cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2) + img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB) + if img_color.shape != image.shape: + img_color = cv2.resize(img_color, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR) + if img_color.shape != img_edge.shape: + img_edge = cv2.resize(img_edge, (img_color.shape[1], img_color.shape[0]), interpolation=cv2.INTER_LINEAR) + return cv2.bitwise_and(img_color, img_edge) + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + self.plugin_options = plugin_options + + def Run(self, temp_frame: Frame) -> Frame: + subtype = self.plugin_options["subtype"] + if subtype == "stylize": + return self.RenderStylize(temp_frame).astype(np.uint8) + if subtype == "detailenhance": + return self.RenderDetailEnhance(temp_frame).astype(np.uint8) + if subtype == "pencil": + return self.RenderPencilSketch(temp_frame).astype(np.uint8) + if subtype == "cartoon": + return self.RenderCartoon(temp_frame).astype(np.uint8) + if subtype == "C64": + return self.RenderC64Screen(temp_frame).astype(np.uint8) + + + def Release(self): + pass + + def getProcessedResolution(self, width, height): + if self.plugin_options["subtype"] == "C64": + return (320,200) + return None + diff --git a/roop-unleashed/roop/processors/Frame_Masking.py b/roop-unleashed/roop/processors/Frame_Masking.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4e77fec51854fc67c5274193665fd3555c24bb --- /dev/null +++ b/roop-unleashed/roop/processors/Frame_Masking.py @@ -0,0 +1,71 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Masking(): + plugin_options:dict = None + model_masking = None + devicename = None + name = None + + processorname = 'removebg' + type = 'frame_masking' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_masking is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"] + self.devicename = self.devicename.replace('mps', 'cpu') + model_path = resolve_relative_path('../models/Frame/isnet-general-use.onnx') + self.model_masking = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_masking.get_inputs() + model_outputs = self.model_masking.get_outputs() + self.io_binding = self.model_masking.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, temp_frame: Frame) -> Frame: + # Pre process:Resize, BGR->RGB, float32 cast + input_image = cv2.resize(temp_frame, (1024, 1024)) + input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) + mean = [0.5, 0.5, 0.5] + std = [1.0, 1.0, 1.0] + input_image = (input_image / 255.0 - mean) / std + input_image = input_image.transpose(2, 0, 1) + input_image = np.expand_dims(input_image, axis=0) + input_image = input_image.astype('float32') + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image) + self.model_masking.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + # Post process:squeeze, Sigmoid, Normarize, uint8 cast + mask = np.squeeze(result[0]) + min_value = np.min(mask) + max_value = np.max(mask) + mask = (mask - min_value) / (max_value - min_value) + #mask = np.where(mask < score_th, 0, 1) + #mask *= 255 + mask = cv2.resize(mask, (temp_frame.shape[1], temp_frame.shape[0]), interpolation=cv2.INTER_LINEAR) + mask = np.reshape(mask, [mask.shape[0],mask.shape[1],1]) + result = mask * temp_frame.astype(np.float32) + return result.astype(np.uint8) + + + + def Release(self): + del self.model_masking + self.model_masking = None + del self.io_binding + self.io_binding = None + diff --git a/roop-unleashed/roop/processors/Frame_Upscale.py b/roop-unleashed/roop/processors/Frame_Upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..e323e98eee7cea6662a6426eb12ebc6a8b753974 --- /dev/null +++ b/roop-unleashed/roop/processors/Frame_Upscale.py @@ -0,0 +1,131 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals +import threading + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Upscale(): + plugin_options:dict = None + model_upscale = None + devicename = None + prev_type = None + + processorname = 'upscale' + type = 'frame_enhancer' + + THREAD_LOCK_UPSCALE = threading.Lock() + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]: + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_upscale is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + if self.prev_type == "esrganx4": + model_path = resolve_relative_path('../models/Frame/real_esrgan_x4.onnx') + self.scale = 4 + elif self.prev_type == "esrganx2": + model_path = resolve_relative_path('../models/Frame/real_esrgan_x2.onnx') + self.scale = 2 + elif self.prev_type == "lsdirx4": + model_path = resolve_relative_path('../models/Frame/lsdir_x4.onnx') + self.scale = 4 + + self.model_upscale = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_upscale.get_inputs() + model_outputs = self.model_upscale.get_outputs() + self.io_binding = self.model_upscale.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def getProcessedResolution(self, width, height): + return (width * self.scale, height * self.scale) + +# borrowed from facefusion -> https://github.com/facefusion/facefusion + def prepare_tile_frame(self, tile_frame : Frame) -> Frame: + tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis = 0) + tile_frame = tile_frame.transpose(0, 3, 1, 2) + tile_frame = tile_frame.astype(np.float32) / 255 + return tile_frame + + + def normalize_tile_frame(self, tile_frame : Frame) -> Frame: + tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255 + tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1] + return tile_frame + + def create_tile_frames(self, input_frame : Frame, size): + input_frame = np.pad(input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) + tile_width = size[0] - 2 * size[2] + pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width + pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width + pad_vision_frame = np.pad(input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0))) + pad_height, pad_width = pad_vision_frame.shape[:2] + row_range = range(size[2], pad_height - size[2], tile_width) + col_range = range(size[2], pad_width - size[2], tile_width) + tile_frames = [] + + for row_frame in row_range: + top = row_frame - size[2] + bottom = row_frame + size[2] + tile_width + for column_vision_frame in col_range: + left = column_vision_frame - size[2] + right = column_vision_frame + size[2] + tile_width + tile_frames.append(pad_vision_frame[top:bottom, left:right, :]) + return tile_frames, pad_width, pad_height + + + def merge_tile_frames(self, tile_frames, temp_width : int, temp_height : int, pad_width : int, pad_height : int, size) -> Frame: + merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8) + tile_width = tile_frames[0].shape[1] - 2 * size[2] + tiles_per_row = min(pad_width // tile_width, len(tile_frames)) + + for index, tile_frame in enumerate(tile_frames): + tile_frame = tile_frame[size[2]:-size[2], size[2]:-size[2]] + row_index = index // tiles_per_row + col_index = index % tiles_per_row + top = row_index * tile_frame.shape[0] + bottom = top + tile_frame.shape[0] + left = col_index * tile_frame.shape[1] + right = left + tile_frame.shape[1] + merge_frame[top:bottom, left:right, :] = tile_frame + merge_frame = merge_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :] + return merge_frame + + + def Run(self, temp_frame: Frame) -> Frame: + size = (128, 8, 2) + temp_height, temp_width = temp_frame.shape[:2] + upscale_tile_frames, pad_width, pad_height = self.create_tile_frames(temp_frame, size) + + for index, tile_frame in enumerate(upscale_tile_frames): + tile_frame = self.prepare_tile_frame(tile_frame) + with self.THREAD_LOCK_UPSCALE: + self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame) + self.model_upscale.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0] + upscale_tile_frames[index] = self.normalize_tile_frame(result) + final_frame = self.merge_tile_frames(upscale_tile_frames, temp_width * self.scale + , temp_height * self.scale + , pad_width * self.scale, pad_height * self.scale + , (size[0] * self.scale, size[1] * self.scale, size[2] * self.scale)) + return final_frame.astype(np.uint8) + + + + def Release(self): + del self.model_upscale + self.model_upscale = None + del self.io_binding + self.io_binding = None + diff --git a/roop-unleashed/roop/processors/Mask_Clip2Seg.py b/roop-unleashed/roop/processors/Mask_Clip2Seg.py new file mode 100644 index 0000000000000000000000000000000000000000..5df3b3e37ea10eb2440828a08e129d8c62f98086 --- /dev/null +++ b/roop-unleashed/roop/processors/Mask_Clip2Seg.py @@ -0,0 +1,94 @@ +import cv2 +import numpy as np +import torch +import threading +from torchvision import transforms +from clip.clipseg import CLIPDensePredT +import numpy as np + +from roop.typing import Frame + +THREAD_LOCK_CLIP = threading.Lock() + + +class Mask_Clip2Seg(): + plugin_options:dict = None + model_clip = None + + processorname = 'clip2seg' + type = 'mask' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_clip is None: + self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) + self.model_clip.eval(); + self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False) + + device = torch.device(self.plugin_options["devicename"]) + self.model_clip.to(device) + + + def Run(self, img1, keywords:str) -> Frame: + if keywords is None or len(keywords) < 1 or img1 is None: + return img1 + + source_image_small = cv2.resize(img1, (256,256)) + + img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32) + mask_border = 1 + l = 0 + t = 0 + r = 1 + b = 1 + + mask_blur = 5 + clip_blur = 5 + + img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)), + (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1) + img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0) + img_mask /= 255 + + + input_image = source_image_small + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + transforms.Resize((256, 256)), + ]) + img = transform(input_image).unsqueeze(0) + + thresh = 0.5 + prompts = keywords.split(',') + with THREAD_LOCK_CLIP: + with torch.no_grad(): + preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0] + clip_mask = torch.sigmoid(preds[0][0]) + for i in range(len(prompts)-1): + clip_mask += torch.sigmoid(preds[i+1][0]) + + clip_mask = clip_mask.data.cpu().numpy() + np.clip(clip_mask, 0, 1) + + clip_mask[clip_mask>thresh] = 1.0 + clip_mask[clip_mask<=thresh] = 0.0 + kernel = np.ones((5, 5), np.float32) + clip_mask = cv2.dilate(clip_mask, kernel, iterations=1) + clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0) + + img_mask *= clip_mask + img_mask[img_mask<0.0] = 0.0 + return img_mask + + + + def Release(self): + self.model_clip = None + diff --git a/roop-unleashed/roop/processors/Mask_XSeg.py b/roop-unleashed/roop/processors/Mask_XSeg.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8e87741c9aa99cde84aa20566bb8c3db548fe2 --- /dev/null +++ b/roop-unleashed/roop/processors/Mask_XSeg.py @@ -0,0 +1,60 @@ +import numpy as np +import cv2 +import onnxruntime +import threading +import roop.globals + +from roop.typing import Frame +from roop.utilities import resolve_relative_path + +THREAD_LOCK_CLIP = threading.Lock() + + +class Mask_XSeg(): + plugin_options:dict = None + + model_xseg = None + + processorname = 'mask_xseg' + type = 'mask' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_xseg is None: + model_path = resolve_relative_path('../models/xseg.onnx') + onnxruntime.set_default_logger_severity(3) + self.model_xseg = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_xseg.get_inputs() + self.model_outputs = self.model_xseg.get_outputs() + + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + + def Run(self, img1, keywords:str) -> Frame: + temp_frame = cv2.resize(img1, (256, 256), cv2.INTER_CUBIC) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = temp_frame[None, ...] + io_binding = self.model_xseg.io_binding() + io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) + io_binding.bind_output(self.model_outputs[0].name, self.devicename) + self.model_xseg.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + result = np.clip(result, 0, 1.0) + result[result < 0.1] = 0 + # invert values to mask areas to keep + result = 1.0 - result + return result + + + def Release(self): + del self.model_xseg + self.model_xseg = None + + diff --git a/pretrained_weights/.huggingface/download/image_encoder/config.json.lock b/roop-unleashed/roop/processors/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pretrained_weights/.huggingface/download/image_encoder/config.json.lock rename to roop-unleashed/roop/processors/__init__.py diff --git a/roop-unleashed/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc b/roop-unleashed/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08b2a9fd8cf366f9f56c0612b91034e1b0025631 Binary files /dev/null and b/roop-unleashed/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc differ diff --git a/roop-unleashed/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc b/roop-unleashed/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e23db6f39b0eea20125f3a99ff8aba1116df79 Binary files /dev/null and b/roop-unleashed/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc differ diff --git a/roop-unleashed/roop/processors/__pycache__/__init__.cpython-310.pyc b/roop-unleashed/roop/processors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b477fce10f565769459daa9c46b19ba286c59529 Binary files /dev/null and b/roop-unleashed/roop/processors/__pycache__/__init__.cpython-310.pyc differ diff --git a/roop-unleashed/roop/template_parser.py b/roop-unleashed/roop/template_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..a51113b69830119fc84fd15c2a428321ac1d8010 --- /dev/null +++ b/roop-unleashed/roop/template_parser.py @@ -0,0 +1,23 @@ +import re +from datetime import datetime + +template_functions = { + "timestamp": lambda data: str(int(datetime.now().timestamp())), + "i": lambda data: data.get("index", False), + "file": lambda data: data.get("file", False), + "date": lambda data: datetime.now().strftime("%Y-%m-%d"), + "time": lambda data: datetime.now().strftime("%H-%M-%S"), +} + + +def parse(text: str, data: dict): + pattern = r"\{([^}]+)\}" + + matches = re.findall(pattern, text) + + for match in matches: + replacement = template_functions[match](data) + if replacement is not False: + text = text.replace(f"{{{match}}}", replacement) + + return text diff --git a/roop-unleashed/roop/typing.py b/roop-unleashed/roop/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..263f1b5b0331332dfab9f682438b364c612cfdf8 --- /dev/null +++ b/roop-unleashed/roop/typing.py @@ -0,0 +1,9 @@ +from typing import Any + +from insightface.app.common import Face +from roop.FaceSet import FaceSet +import numpy + +Face = Face +FaceSet = FaceSet +Frame = numpy.ndarray[Any, Any] diff --git a/roop-unleashed/roop/util_ffmpeg.py b/roop-unleashed/roop/util_ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..8b8c9a978f2acfd647c5e1088c0264e9193b68be --- /dev/null +++ b/roop-unleashed/roop/util_ffmpeg.py @@ -0,0 +1,112 @@ + +import os +import subprocess +import roop.globals +import roop.utilities as util + +from typing import List, Any + +def run_ffmpeg(args: List[str]) -> bool: + commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-y', '-loglevel', roop.globals.log_level] + commands.extend(args) + print("Running ffmpeg") + try: + subprocess.check_output(commands, stderr=subprocess.STDOUT) + return True + except Exception as e: + print("Running ffmpeg failed! Commandline:") + print(" ".join(map(str, commands))) # Ensure all elements are strings + print(e) + return False + + + +def cut_video(original_video: str, cut_video: str, start_frame: int, end_frame: int, reencode: bool): + fps = util.detect_fps(original_video) + start_time = start_frame / fps + num_frames = end_frame - start_frame + + if reencode: + run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-c:v', roop.globals.video_encoder, '-c:a', 'aac', '-frames:v', str(num_frames), cut_video]) + else: + run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-frames:v', str(num_frames), '-c:v' ,'copy','-c:a' ,'copy', cut_video]) + +def join_videos(videos: List[str], dest_filename: str, simple: bool): + if simple: + txtfilename = util.resolve_relative_path('../temp') + txtfilename = os.path.join(txtfilename, 'joinvids.txt') + with open(txtfilename, "w", encoding="utf-8") as f: + for v in videos: + v = v.replace('\\', '/') + f.write(f"file {v}\n") + commands = ['-f', 'concat', '-safe', '0', '-i', f'{txtfilename}', '-vcodec', 'copy', f'{dest_filename}'] + run_ffmpeg(commands) + + else: + inputs = [] + filter = '' + for i,v in enumerate(videos): + inputs.append('-i') + inputs.append(v) + filter += f'[{i}:v:0][{i}:a:0]' + run_ffmpeg([" ".join(inputs), '-filter_complex', f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', '-map', '"[outv]"', '-map', '"[outa]"', dest_filename]) + + # filter += f'[{i}:v:0][{i}:a:0]' + # run_ffmpeg([" ".join(inputs), '-filter_complex', f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', '-map', '"[outv]"', '-map', '"[outa]"', dest_filename]) + + + +def extract_frames(target_path : str, trim_frame_start, trim_frame_end, fps : float) -> bool: + util.create_temp(target_path) + temp_directory_path = util.get_temp_directory_path(target_path) + commands = ['-i', target_path, '-q:v', '1', '-pix_fmt', 'rgb24', ] + if trim_frame_start is not None and trim_frame_end is not None: + commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',fps=' + str(fps) ]) + commands.extend(['-vsync', '0', os.path.join(temp_directory_path, '%06d.' + roop.globals.CFG.output_image_format)]) + return run_ffmpeg(commands) + + +def create_video(target_path: str, dest_filename: str, fps: float = 24.0, temp_directory_path: str = None) -> None: + if temp_directory_path is None: + temp_directory_path = util.get_temp_directory_path(target_path) + print("dest file name is " + dest_filename) + run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, f'%06d.{roop.globals.CFG.output_image_format}'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', dest_filename]) + return dest_filename + + +def create_gif_from_video(video_path: str, gif_path): + from roop.capturer import get_video_frame + + fps = util.detect_fps(video_path) + frame = get_video_frame(video_path) + + run_ffmpeg(['-i', video_path, '-vf', f'fps={fps},scale={frame.shape[0]}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse', '-loop', '0', gif_path]) + + +def restore_audio(intermediate_video: str, original_video: str, trim_frame_start, trim_frame_end, final_video : str) -> None: + fps = util.detect_fps(original_video) + commands = [ '-i', intermediate_video ] + if trim_frame_start is None and trim_frame_end is None: + commands.extend([ '-c:a', 'copy' ]) + else: + # if trim_frame_start is not None: + # start_time = trim_frame_start / fps + # commands.extend([ '-ss', format(start_time, ".2f")]) + # else: + # commands.extend([ '-ss', '0' ]) + # if trim_frame_end is not None: + # end_time = trim_frame_end / fps + # commands.extend([ '-to', format(end_time, ".2f")]) + # commands.extend([ '-c:a', 'aac' ]) + if trim_frame_start is not None: + start_time = trim_frame_start / fps + commands.extend([ '-ss', format(start_time, ".2f")]) + else: + commands.extend([ '-ss', '0' ]) + if trim_frame_end is not None: + end_time = trim_frame_end / fps + commands.extend([ '-to', format(end_time, ".2f")]) + commands.extend([ '-i', original_video, "-c", "copy" ]) + + commands.extend([ '-map', '0:v:0', '-map', '1:a:0?', '-shortest', final_video ]) + run_ffmpeg(commands) diff --git a/roop-unleashed/roop/utilities.py b/roop-unleashed/roop/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7a41c079e0d71e81c9d3326100a598471c79a0 --- /dev/null +++ b/roop-unleashed/roop/utilities.py @@ -0,0 +1,339 @@ +import glob +import mimetypes +import os +import platform +import shutil +import ssl +import subprocess +import sys +import urllib +import torch +import gradio +import tempfile +import cv2 +import zipfile +import traceback + +from pathlib import Path +from typing import List, Any +from tqdm import tqdm +from scipy.spatial import distance + +import roop.template_parser as template_parser + +import roop.globals + +TEMP_FILE = "temp.mp4" +TEMP_DIRECTORY = "temp" + +# monkey patch ssl for mac +if platform.system().lower() == "darwin": + ssl._create_default_https_context = ssl._create_unverified_context + + +# https://github.com/facefusion/facefusion/blob/master/facefusion +def detect_fps(target_path: str) -> float: + fps = 24.0 + cap = cv2.VideoCapture(target_path) + if cap.isOpened(): + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + return fps + + +# Gradio wants Images in RGB +def convert_to_gradio(image): + if image is None: + return None + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +def sort_filenames_ignore_path(filenames): + """Sorts a list of filenames containing a complete path by their filename, + while retaining their original path. + + Args: + filenames: A list of filenames containing a complete path. + + Returns: + A sorted list of filenames containing a complete path. + """ + filename_path_tuples = [ + (os.path.split(filename)[1], filename) for filename in filenames + ] + sorted_filename_path_tuples = sorted(filename_path_tuples, key=lambda x: x[0]) + return [ + filename_path_tuple[1] for filename_path_tuple in sorted_filename_path_tuples + ] + + +def sort_rename_frames(path: str): + filenames = os.listdir(path) + filenames.sort() + for i in range(len(filenames)): + of = os.path.join(path, filenames[i]) + newidx = i + 1 + new_filename = os.path.join( + path, f"{newidx:06d}." + roop.globals.CFG.output_image_format + ) + os.rename(of, new_filename) + + +def get_temp_frame_paths(target_path: str) -> List[str]: + temp_directory_path = get_temp_directory_path(target_path) + return glob.glob( + ( + os.path.join( + glob.escape(temp_directory_path), + f"*.{roop.globals.CFG.output_image_format}", + ) + ) + ) + + +def get_temp_directory_path(target_path: str) -> str: + target_name, _ = os.path.splitext(os.path.basename(target_path)) + target_directory_path = os.path.dirname(target_path) + return os.path.join(target_directory_path, TEMP_DIRECTORY, target_name) + + +def get_temp_output_path(target_path: str) -> str: + temp_directory_path = get_temp_directory_path(target_path) + return os.path.join(temp_directory_path, TEMP_FILE) + + +def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any: + if source_path and target_path: + source_name, _ = os.path.splitext(os.path.basename(source_path)) + target_name, target_extension = os.path.splitext(os.path.basename(target_path)) + if os.path.isdir(output_path): + return os.path.join( + output_path, source_name + "-" + target_name + target_extension + ) + return output_path + + +def get_destfilename_from_path( + srcfilepath: str, destfilepath: str, extension: str +) -> str: + fn, ext = os.path.splitext(os.path.basename(srcfilepath)) + if "." in extension: + return os.path.join(destfilepath, f"{fn}{extension}") + return os.path.join(destfilepath, f"{fn}{extension}{ext}") + + +def replace_template(file_path: str, index: int = 0) -> str: + fn, ext = os.path.splitext(os.path.basename(file_path)) + + # Remove the "__temp" placeholder that was used as a temporary filename + fn = fn.replace("__temp", "") + + template = roop.globals.CFG.output_template + replaced_filename = template_parser.parse( + template, {"index": str(index), "file": fn} + ) + + return os.path.join(roop.globals.output_path, f"{replaced_filename}{ext}") + + +def create_temp(target_path: str) -> None: + temp_directory_path = get_temp_directory_path(target_path) + Path(temp_directory_path).mkdir(parents=True, exist_ok=True) + + +def move_temp(target_path: str, output_path: str) -> None: + temp_output_path = get_temp_output_path(target_path) + if os.path.isfile(temp_output_path): + if os.path.isfile(output_path): + os.remove(output_path) + shutil.move(temp_output_path, output_path) + + +def clean_temp(target_path: str) -> None: + temp_directory_path = get_temp_directory_path(target_path) + parent_directory_path = os.path.dirname(temp_directory_path) + if not roop.globals.keep_frames and os.path.isdir(temp_directory_path): + shutil.rmtree(temp_directory_path) + if os.path.exists(parent_directory_path) and not os.listdir(parent_directory_path): + os.rmdir(parent_directory_path) + + +def delete_temp_frames(filename: str) -> None: + dir = os.path.dirname(os.path.dirname(filename)) + shutil.rmtree(dir) + + +def has_image_extension(image_path: str) -> bool: + return image_path.lower().endswith(("png", "jpg", "jpeg", "webp")) + + +def has_extension(filepath: str, extensions: List[str]) -> bool: + return filepath.lower().endswith(tuple(extensions)) + + +def is_image(image_path: str) -> bool: + if image_path and os.path.isfile(image_path): + mimetype, _ = mimetypes.guess_type(image_path) + return bool(mimetype and mimetype.startswith("image/")) + return False + + +def is_video(video_path: str) -> bool: + if video_path and os.path.isfile(video_path): + mimetype, _ = mimetypes.guess_type(video_path) + return bool(mimetype and mimetype.startswith("video/")) + return False + + +def conditional_download(download_directory_path: str, urls: List[str]) -> None: + if not os.path.exists(download_directory_path): + os.makedirs(download_directory_path) + for url in urls: + download_file_path = os.path.join( + download_directory_path, os.path.basename(url) + ) + if not os.path.exists(download_file_path): + request = urllib.request.urlopen(url) # type: ignore[attr-defined] + total = int(request.headers.get("Content-Length", 0)) + with tqdm( + total=total, + desc=f"Downloading {url}", + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as progress: + urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined] + + +def get_local_files_from_folder(folder: str) -> List[str]: + if not os.path.exists(folder) or not os.path.isdir(folder): + return None + files = [ + os.path.join(folder, f) + for f in os.listdir(folder) + if os.path.isfile(os.path.join(folder, f)) + ] + return files + + +def resolve_relative_path(path: str) -> str: + return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) + + +def get_device() -> str: + if len(roop.globals.execution_providers) < 1: + roop.globals.execution_providers = ["CPUExecutionProvider"] + + prov = roop.globals.execution_providers[0] + if "CoreMLExecutionProvider" in prov: + return "mps" + if "CUDAExecutionProvider" in prov or "ROCMExecutionProvider" in prov: + return "cuda" + if "OpenVINOExecutionProvider" in prov: + return "mkl" + return "cpu" + + +def str_to_class(module_name, class_name) -> Any: + from importlib import import_module + + class_ = None + try: + module_ = import_module(module_name) + try: + class_ = getattr(module_, class_name)() + except AttributeError: + print(f"Class {class_name} does not exist") + except ImportError: + print(f"Module {module_name} does not exist") + return class_ + +def is_installed(name:str) -> bool: + return shutil.which(name); + +# Taken from https://stackoverflow.com/a/68842705 +def get_platform() -> str: + if sys.platform == "linux": + try: + proc_version = open("/proc/version").read() + if "Microsoft" in proc_version: + return "wsl" + except: + pass + return sys.platform + +def open_with_default_app(filename:str): + if filename == None: + return + platform = get_platform() + if platform == "darwin": + subprocess.call(("open", filename)) + elif platform in ["win64", "win32"]: os.startfile(filename.replace("/", "\\")) + elif platform == "wsl": + subprocess.call("cmd.exe /C start".split() + [filename]) + else: # linux variants + subprocess.call("xdg-open", filename) + + +def prepare_for_batch(target_files) -> str: + print("Preparing temp files") + tempfolder = os.path.join(tempfile.gettempdir(), "rooptmp") + if os.path.exists(tempfolder): + shutil.rmtree(tempfolder) + Path(tempfolder).mkdir(parents=True, exist_ok=True) + for f in target_files: + newname = os.path.basename(f.name) + shutil.move(f.name, os.path.join(tempfolder, newname)) + return tempfolder + + +def zip(files, zipname): + with zipfile.ZipFile(zipname, "w") as zip_file: + for f in files: + zip_file.write(f, os.path.basename(f)) + + +def unzip(zipfilename: str, target_path: str): + with zipfile.ZipFile(zipfilename, "r") as zip_file: + zip_file.extractall(target_path) + + +def mkdir_with_umask(directory): + oldmask = os.umask(0) + # mode needs octal + os.makedirs(directory, mode=0o775, exist_ok=True) + os.umask(oldmask) + + +def open_folder(path: str): + platform = get_platform() + try: + if platform == "darwin": + subprocess.call(("open", path)) + elif platform in ["win64", "win32"]: + open_with_default_app(path) + elif platform == "wsl": + subprocess.call("cmd.exe /C start".split() + [path]) + else: # linux variants + subprocess.Popen(["xdg-open", path]) + except Exception as e: + traceback.print_exc() + pass + # import webbrowser + # webbrowser.open(url) + + +def create_version_html() -> str: + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + versions_html = f""" +python: {python_version} +โ€ข +torch: {getattr(torch, '__long_version__',torch.__version__)} +โ€ข +gradio: {gradio.__version__} +""" + return versions_html + + +def compute_cosine_distance(emb1, emb2) -> float: + return distance.cosine(emb1, emb2) diff --git a/roop-unleashed/roop/virtualcam.py b/roop-unleashed/roop/virtualcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d429851bb610789386a4d11866d2663f43bd78be --- /dev/null +++ b/roop-unleashed/roop/virtualcam.py @@ -0,0 +1,87 @@ +import cv2 +import roop.globals +import ui.globals +import pyvirtualcam +import threading +import platform + + +cam_active = False +cam_thread = None +vcam = None + +def virtualcamera(streamobs, cam_num,width,height): + from roop.ProcessOptions import ProcessOptions + from roop.core import live_swap, get_processing_plugins + + global cam_active + + #time.sleep(2) + print('Starting capture') + cap = cv2.VideoCapture(cam_num, cv2.CAP_DSHOW if platform.system() != 'Darwin' else cv2.CAP_AVFOUNDATION) + if not cap.isOpened(): + print("Cannot open camera") + cap.release() + del cap + return + + pref_width = width + pref_height = height + pref_fps_in = 30 + cap.set(cv2.CAP_PROP_FRAME_WIDTH, pref_width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, pref_height) + cap.set(cv2.CAP_PROP_FPS, pref_fps_in) + cam_active = True + + # native format UYVY + + cam = None + if streamobs: + print('Detecting virtual cam devices') + cam = pyvirtualcam.Camera(width=pref_width, height=pref_height, fps=pref_fps_in, fmt=pyvirtualcam.PixelFormat.BGR, print_fps=False) + if cam: + print(f'Using virtual camera: {cam.device}') + print(f'Using {cam.native_fmt}') + else: + print(f'Not streaming to virtual camera!') + + # always use xseg masking + options = ProcessOptions(get_processing_plugins("mask_xseg"), roop.globals.distance_threshold, roop.globals.blend_ratio, + "all", 0, None, None, 1, False) + while cam_active: + ret, frame = cap.read() + if not ret: + break + + if len(roop.globals.INPUT_FACESETS) > 0: + frame = live_swap(frame, options) + if cam: + cam.send(frame) + cam.sleep_until_next_frame() + ui.globals.ui_camera_frame = frame + + if cam: + cam.close() + cap.release() + print('Camera stopped') + + + +def start_virtual_cam(streamobs, cam_number, resolution): + global cam_thread, cam_active + + if not cam_active: + width, height = map(int, resolution.split('x')) + cam_thread = threading.Thread(target=virtualcamera, args=[streamobs, cam_number, width, height]) + cam_thread.start() + + + +def stop_virtual_cam(): + global cam_active, cam_thread + + if cam_active: + cam_active = False + cam_thread.join() + + diff --git a/roop-unleashed/roop/vr_util.py b/roop-unleashed/roop/vr_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a72845e3c2c3cc89f6567ebfc13bf77d306710ff --- /dev/null +++ b/roop-unleashed/roop/vr_util.py @@ -0,0 +1,57 @@ +import cv2 +import numpy as np + +# VR Lense Distortion +# Taken from https://github.com/g0kuvonlange/vrswap + + +def get_perspective(img, FOV, THETA, PHI, height, width): + # + # THETA is left/right angle, PHI is up/down angle, both in degree + # + [orig_width, orig_height, _] = img.shape + equ_h = orig_height + equ_w = orig_width + equ_cx = (equ_w - 1) / 2.0 + equ_cy = (equ_h - 1) / 2.0 + + wFOV = FOV + hFOV = float(height) / width * wFOV + + w_len = np.tan(np.radians(wFOV / 2.0)) + h_len = np.tan(np.radians(hFOV / 2.0)) + + x_map = np.ones([height, width], np.float32) + y_map = np.tile(np.linspace(-w_len, w_len, width), [height, 1]) + z_map = -np.tile(np.linspace(-h_len, h_len, height), [width, 1]).T + + D = np.sqrt(x_map**2 + y_map**2 + z_map**2) + xyz = np.stack((x_map, y_map, z_map), axis=2) / np.repeat( + D[:, :, np.newaxis], 3, axis=2 + ) + + y_axis = np.array([0.0, 1.0, 0.0], np.float32) + z_axis = np.array([0.0, 0.0, 1.0], np.float32) + [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA)) + [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(-PHI)) + + xyz = xyz.reshape([height * width, 3]).T + xyz = np.dot(R1, xyz) + xyz = np.dot(R2, xyz).T + lat = np.arcsin(xyz[:, 2]) + lon = np.arctan2(xyz[:, 1], xyz[:, 0]) + + lon = lon.reshape([height, width]) / np.pi * 180 + lat = -lat.reshape([height, width]) / np.pi * 180 + + lon = lon / 180 * equ_cx + equ_cx + lat = lat / 90 * equ_cy + equ_cy + + persp = cv2.remap( + img, + lon.astype(np.float32), + lat.astype(np.float32), + cv2.INTER_CUBIC, + borderMode=cv2.BORDER_WRAP, + ) + return persp diff --git a/roop-unleashed/run.py b/roop-unleashed/run.py new file mode 100755 index 0000000000000000000000000000000000000000..b52e5cc4a8ea9ce5cadd4e7111fb15531f380314 --- /dev/null +++ b/roop-unleashed/run.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from roop import core + +if __name__ == '__main__': + core.run() diff --git a/roop-unleashed/settings.py b/roop-unleashed/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..eaed8e0d33375c38c0bf44f0d79c96f0c646c36d --- /dev/null +++ b/roop-unleashed/settings.py @@ -0,0 +1,68 @@ +import yaml + +class Settings: + def __init__(self, config_file): + self.config_file = config_file + self.load() + + def default_get(_, data, name, default): + value = default + try: + value = data.get(name, default) + except: + pass + return value + + + def load(self): + try: + with open(self.config_file, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + except: + data = None + + self.selected_theme = self.default_get(data, 'selected_theme', "Default") + self.server_name = self.default_get(data, 'server_name', "") + self.server_port = self.default_get(data, 'server_port', 0) + self.server_share = self.default_get(data, 'server_share', False) + self.output_image_format = self.default_get(data, 'output_image_format', 'png') + self.output_video_format = self.default_get(data, 'output_video_format', 'mp4') + self.output_video_codec = self.default_get(data, 'output_video_codec', 'libx264') + self.video_quality = self.default_get(data, 'video_quality', 14) + self.clear_output = self.default_get(data, 'clear_output', True) + self.max_threads = self.default_get(data, 'max_threads', 2) + self.memory_limit = self.default_get(data, 'memory_limit', 0) + self.provider = self.default_get(data, 'provider', 'cuda') + self.force_cpu = self.default_get(data, 'force_cpu', False) + self.output_template = self.default_get(data, 'output_template', '{file}_{time}') + self.use_os_temp_folder = self.default_get(data, 'use_os_temp_folder', False) + self.output_show_video = self.default_get(data, 'output_show_video', True) + + + + + + def save(self): + data = { + 'selected_theme': self.selected_theme, + 'server_name': self.server_name, + 'server_port': self.server_port, + 'server_share': self.server_share, + 'output_image_format' : self.output_image_format, + 'output_video_format' : self.output_video_format, + 'output_video_codec' : self.output_video_codec, + 'video_quality' : self.video_quality, + 'clear_output' : self.clear_output, + 'max_threads' : self.max_threads, + 'memory_limit' : self.memory_limit, + 'provider' : self.provider, + 'force_cpu' : self.force_cpu, + 'output_template' : self.output_template, + 'use_os_temp_folder' : self.use_os_temp_folder, + 'output_show_video' : self.output_show_video + } + with open(self.config_file, 'w') as f: + yaml.dump(data, f) + + + diff --git a/roop-unleashed/ui/globals.py b/roop-unleashed/ui/globals.py new file mode 100644 index 0000000000000000000000000000000000000000..5514a63d6e6e00bfb72938f8648e7eb5575d601a --- /dev/null +++ b/roop-unleashed/ui/globals.py @@ -0,0 +1,15 @@ +ui_restart_server = False + +SELECTION_FACES_DATA = None +ui_SELECTED_INPUT_FACE_INDEX = 0 + +ui_selected_enhancer = None +ui_blend_ratio = None +ui_input_thumbs = [] +ui_target_thumbs = [] +ui_camera_frame = None + + + + + diff --git a/roop-unleashed/ui/main.py b/roop-unleashed/ui/main.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf64b1b63f132119432cc795ca76fcb5b134200 --- /dev/null +++ b/roop-unleashed/ui/main.py @@ -0,0 +1,88 @@ +import os +import time +import gradio as gr +import roop.globals +import roop.metadata +import roop.utilities as util +import ui.globals as uii + +from ui.tabs.faceswap_tab import faceswap_tab +from ui.tabs.livecam_tab import livecam_tab +from ui.tabs.facemgr_tab import facemgr_tab +from ui.tabs.extras_tab import extras_tab +from ui.tabs.settings_tab import settings_tab + +roop.globals.keep_fps = None +roop.globals.keep_frames = None +roop.globals.skip_audio = None +roop.globals.use_batch = None + + +def prepare_environment(): + roop.globals.output_path = os.path.abspath(os.path.join(os.getcwd(), "output")) + os.makedirs(roop.globals.output_path, exist_ok=True) + if not roop.globals.CFG.use_os_temp_folder: + os.environ["TEMP"] = os.environ["TMP"] = os.path.abspath(os.path.join(os.getcwd(), "temp")) + os.makedirs(os.environ["TEMP"], exist_ok=True) + os.environ["GRADIO_TEMP_DIR"] = os.environ["TEMP"] + + +def run(): + from roop.core import decode_execution_providers, set_display_ui + + prepare_environment() + + set_display_ui(show_msg) + roop.globals.execution_providers = decode_execution_providers([roop.globals.CFG.provider]) + print(f'Using provider {roop.globals.execution_providers} - Device:{util.get_device()}') + + run_server = True + uii.ui_restart_server = False + mycss = """ + span {color: var(--block-info-text-color)} + #fixedheight { + max-height: 238.4px; + overflow-y: auto !important; + } + .image-container.svelte-1l6wqyv {height: 100%} + + """ + + while run_server: + server_name = roop.globals.CFG.server_name + if server_name is None or len(server_name) < 1: + server_name = None + server_port = roop.globals.CFG.server_port + if server_port <= 0: + server_port = None + ssl_verify = False if server_name == '0.0.0.0' else True + with gr.Blocks(title=f'{roop.metadata.name} {roop.metadata.version}', theme=roop.globals.CFG.selected_theme, css=mycss) as ui: + with gr.Row(variant='compact'): + gr.Markdown(f"### [{roop.metadata.name} {roop.metadata.version}](https://github.com/C0untFloyd/roop-unleashed)") + gr.HTML(util.create_version_html(), elem_id="versions") + faceswap_tab() + livecam_tab() + facemgr_tab() + extras_tab() + settings_tab() + + uii.ui_restart_server = False + try: + ui.queue().launch(inbrowser=True, server_name=server_name, server_port=server_port, share=roop.globals.CFG.server_share, ssl_verify=ssl_verify, prevent_thread_lock=True, show_error=True) + except Exception as e: + print(f'Exception {e} when launching Gradio Server!') + uii.ui_restart_server = True + run_server = False + try: + while uii.ui_restart_server == False: + time.sleep(1.0) + + except (KeyboardInterrupt, OSError): + print("Keyboard interruption in main thread... closing server.") + run_server = False + ui.close() + + +def show_msg(msg: str): + gr.Info(msg) + diff --git a/roop-unleashed/ui/tabs/extras_tab.py b/roop-unleashed/ui/tabs/extras_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..7686542649857d1d2b61722460066ea71f38d12d --- /dev/null +++ b/roop-unleashed/ui/tabs/extras_tab.py @@ -0,0 +1,184 @@ +import os +import gradio as gr +import shutil +import roop.utilities as util +import roop.util_ffmpeg as ffmpeg +import roop.globals + +frame_filters_map = { + "Colorize B/W Images (Deoldify Artistic)" : {"colorizer" : {"subtype": "deoldify_artistic"}}, + "Colorize B/W Images (Deoldify Stable)" : {"colorizer" : {"subtype": "deoldify_stable"}}, + "Background remove" : {"removebg" : {"subtype": ""}}, + "Filter Stylize" : {"filter_generic" : {"subtype" : "stylize" }}, + "Filter Detail Enhance" : {"filter_generic" : {"subtype" : "detailenhance" }}, + "Filter Pencil Sketch" : {"filter_generic" : {"subtype" : "pencil" }}, + "Filter Cartoon" : {"filter_generic" : {"subtype" : "cartoon" }}, + "Filter C64" : {"filter_generic" : {"subtype" : "C64" }} + } + +frame_upscalers_map = { + "ESRGAN x2" : {"upscale" : {"subtype": "esrganx2"}}, + "ESRGAN x4" : {"upscale" : {"subtype": "esrganx4"}}, + "LSDIR x4" : {"upscale" : {"subtype": "lsdirx4"}} +} + +def extras_tab(): + filternames = ["None"] + for f in frame_filters_map.keys(): + filternames.append(f) + upscalernames = ["None"] + for f in frame_upscalers_map.keys(): + upscalernames.append(f) + + with gr.Tab("๐ŸŽ‰ Extras"): + with gr.Row(): + files_to_process = gr.Files(label='File(s) to process', file_count="multiple", file_types=["image", "video"]) + with gr.Row(variant='panel'): + with gr.Accordion(label="Video/GIF", open=False): + with gr.Row(variant='panel'): + with gr.Column(): + gr.Markdown(""" + # Poor man's video editor + Re-encoding uses your configuration from the Settings Tab. + """) + with gr.Column(): + cut_start_time = gr.Slider(0, 1000000, value=0, label="Start Frame", step=1.0, interactive=True) + with gr.Column(): + cut_end_time = gr.Slider(1, 1000000, value=1, label="End Frame", step=1.0, interactive=True) + with gr.Column(): + extras_chk_encode = gr.Checkbox(label='Re-encode videos (necessary for videos with different codecs)', value=False) + start_cut_video = gr.Button("Cut video") + start_extract_frames = gr.Button("Extract frames") + start_join_videos = gr.Button("Join videos") + + with gr.Row(variant='panel'): + with gr.Column(): + gr.Markdown(""" + # Create video/gif from images + """) + with gr.Column(): + extras_fps = gr.Slider(minimum=0, maximum=120, value=30, label="Video FPS", step=1.0, interactive=True) + extras_images_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True) + with gr.Column(): + extras_chk_creategif = gr.Checkbox(label='Create GIF from video', value=False) + extras_create_video=gr.Button("Create") + with gr.Row(variant='panel'): + with gr.Accordion(label="Full frame processing", open=True): + with gr.Row(variant='panel'): + filterselection = gr.Dropdown(filternames, value="None", label="Colorizer/FilterFX", interactive=True) + upscalerselection = gr.Dropdown(upscalernames, value="None", label="Enhancer", interactive=True) + with gr.Row(variant='panel'): + start_frame_process=gr.Button("Start processing") + + with gr.Row(): + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Row(): + extra_files_output = gr.Files(label='Resulting output files', file_count="multiple") + + start_cut_video.click(fn=on_cut_video, inputs=[files_to_process, cut_start_time, cut_end_time, extras_chk_encode], outputs=[extra_files_output]) + start_extract_frames.click(fn=on_extras_extract_frames, inputs=[files_to_process], outputs=[extra_files_output]) + start_join_videos.click(fn=on_join_videos, inputs=[files_to_process, extras_chk_encode], outputs=[extra_files_output]) + extras_create_video.click(fn=on_extras_create_video, inputs=[extras_images_folder, extras_fps, extras_chk_creategif], outputs=[extra_files_output]) + start_frame_process.click(fn=on_frame_process, inputs=[files_to_process, filterselection, upscalerselection], outputs=[extra_files_output]) + + +def on_cut_video(files, cut_start_frame, cut_end_frame, reencode): + if files is None: + return None + + resultfiles = [] + for tf in files: + f = tf.name + destfile = util.get_destfilename_from_path(f, roop.globals.output_path, '_cut') + ffmpeg.cut_video(f, destfile, cut_start_frame, cut_end_frame, reencode) + if os.path.isfile(destfile): + resultfiles.append(destfile) + else: + gr.Error('Cutting video failed!') + return resultfiles + + +def on_join_videos(files, chk_encode): + if files is None: + return None + + filenames = [] + for f in files: + filenames.append(f.name) + destfile = util.get_destfilename_from_path(filenames[0], roop.globals.output_path, '_join') + sorted_filenames = util.sort_filenames_ignore_path(filenames) + ffmpeg.join_videos(sorted_filenames, destfile, not chk_encode) + resultfiles = [] + if os.path.isfile(destfile): + resultfiles.append(destfile) + else: + gr.Error('Joining videos failed!') + return resultfiles + + + +def on_extras_create_video(images_path,fps, create_gif): + util.sort_rename_frames(os.path.dirname(images_path)) + destfilename = os.path.join(roop.globals.output_path, "img2video." + roop.globals.CFG.output_video_format) + ffmpeg.create_video('', destfilename, fps, images_path) + resultfiles = [] + if os.path.isfile(destfilename): + resultfiles.append(destfilename) + else: + return None + if create_gif: + gifname = util.get_destfilename_from_path(destfilename, './output', '.gif') + ffmpeg.create_gif_from_video(destfilename, gifname) + if os.path.isfile(destfilename): + resultfiles.append(gifname) + return resultfiles + + +def on_extras_extract_frames(files): + if files is None: + return None + + resultfiles = [] + for tf in files: + f = tf.name + resfolder = ffmpeg.extract_frames(f) + for file in os.listdir(resfolder): + outfile = os.path.join(resfolder, file) + if os.path.isfile(outfile): + resultfiles.append(outfile) + return resultfiles + + +def on_frame_process(files, filterselection, upscaleselection): + import pathlib + from roop.core import batch_process_with_options + from roop.ProcessEntry import ProcessEntry + from roop.ProcessOptions import ProcessOptions + from ui.main import prepare_environment + + + if files is None: + return None + + if roop.globals.CFG.clear_output: + shutil.rmtree(roop.globals.output_path) + prepare_environment() + list_files_process : list[ProcessEntry] = [] + + for tf in files: + list_files_process.append(ProcessEntry(tf.name, 0,0, 0)) + + processoroptions = {} + filter = next((x for x in frame_filters_map.keys() if x == filterselection), None) + if filter is not None: + processoroptions.update(frame_filters_map[filter]) + filter = next((x for x in frame_upscalers_map.keys() if x == upscaleselection), None) + if filter is not None: + processoroptions.update(frame_upscalers_map[filter]) + options = ProcessOptions(processoroptions, 0, 0, "all", 0, None, None, None, False) + batch_process_with_options(list_files_process, options, None) + outdir = pathlib.Path(roop.globals.output_path) + outfiles = [str(item) for item in outdir.rglob("*") if item.is_file()] + return outfiles + + diff --git a/roop-unleashed/ui/tabs/facemgr_tab.py b/roop-unleashed/ui/tabs/facemgr_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..05a5ac3f53c02703a8b14216b1fd051b57f87977 --- /dev/null +++ b/roop-unleashed/ui/tabs/facemgr_tab.py @@ -0,0 +1,187 @@ +import os +import shutil +import cv2 +import gradio as gr +import roop.utilities as util +import roop.globals +from roop.face_util import extract_face_images +from roop.capturer import get_video_frame, get_video_frame_total +from typing import List, Tuple, Optional +from roop.typing import Frame, Face, FaceSet + +selected_face_index = -1 +thumbs = [] +images = [] + + +def facemgr_tab() -> None: + with gr.Tab("๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ Face Management"): + with gr.Row(): + gr.Markdown(""" + # Create blending facesets + Add multiple reference images into a faceset file. + """) + with gr.Row(): + videoimagefst = gr.Image(label="Cut face from video frame", height=576, interactive=False, visible=True) + with gr.Row(): + frame_num_fst = gr.Slider(1, 1, value=1, label="Frame Number", info='0:00:00', step=1.0, interactive=False) + fb_cutfromframe = gr.Button("Use faces from this frame", variant='secondary', interactive=False) + with gr.Row(): + fb_facesetfile = gr.Files(label='Faceset', file_count='single', file_types=['.fsz'], interactive=True) + fb_files = gr.Files(label='Input Files', file_count="multiple", file_types=["image", "video"], interactive=True) + with gr.Row(): + with gr.Column(): + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Column(): + gr.Markdown(' ') + with gr.Row(): + faces = gr.Gallery(label="Faces in this Faceset", allow_preview=True, preview=True, height=128, object_fit="scale-down") + with gr.Row(): + fb_remove = gr.Button("Remove selected", variant='secondary') + fb_update = gr.Button("Create/Update Faceset file", variant='primary') + fb_clear = gr.Button("Clear all", variant='stop') + + fb_facesetfile.change(fn=on_faceset_changed, inputs=[fb_facesetfile], outputs=[faces]) + fb_files.change(fn=on_fb_files_changed, inputs=[fb_files], outputs=[faces, videoimagefst, frame_num_fst, fb_cutfromframe]) + fb_update.click(fn=on_update_clicked, outputs=[fb_facesetfile]) + fb_remove.click(fn=on_remove_clicked, outputs=[faces]) + fb_clear.click(fn=on_clear_clicked, outputs=[faces, fb_files, fb_facesetfile]) + fb_cutfromframe.click(fn=on_cutfromframe_clicked, inputs=[fb_files, frame_num_fst], outputs=[faces]) + frame_num_fst.release(fn=on_frame_num_fst_changed, inputs=[fb_files, frame_num_fst], outputs=[videoimagefst]) + faces.select(fn=on_face_selected) + + +def on_faceset_changed(faceset, progress=gr.Progress()) -> List[Frame]: + global thumbs, images + + if faceset is None: + return thumbs + + thumbs.clear() + filename = faceset.name + + if filename.lower().endswith('fsz'): + progress(0, desc="Retrieving faces from Faceset File", ) + unzipfolder = os.path.join(os.environ["TEMP"], 'faceset') + if os.path.isdir(unzipfolder): + shutil.rmtree(unzipfolder) + util.mkdir_with_umask(unzipfolder) + util.unzip(filename, unzipfolder) + for file in os.listdir(unzipfolder): + if file.endswith(".png"): + SELECTION_FACES_DATA = extract_face_images(os.path.join(unzipfolder,file), (False, 0), 0.5) + if len(SELECTION_FACES_DATA) < 1: + gr.Warning(f"No face detected in {file}!") + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + + return thumbs + + +def on_fb_files_changed(inputfiles, progress=gr.Progress()) -> Tuple[List[Frame], Optional[gr.Image], Optional[gr.Slider], Optional[gr.Button]]: + global thumbs, images, total_frames, current_video_fps + + if inputfiles is None or len(inputfiles) < 1: + return thumbs, None, None, None + + progress(0, desc="Retrieving faces from images", ) + slider = None + video_image = None + cut_button = None + for f in inputfiles: + source_path = f.name + if util.has_image_extension(source_path): + slider = gr.Slider(interactive=False) + video_image = gr.Image(interactive=False) + cut_button = gr.Button(interactive=False) + roop.globals.source_path = source_path + SELECTION_FACES_DATA = extract_face_images(roop.globals.source_path, (False, 0), 0.5) + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + elif util.is_video(source_path) or source_path.lower().endswith('gif'): + total_frames = get_video_frame_total(source_path) + current_video_fps = util.detect_fps(source_path) + cut_button = gr.Button(interactive=True) + video_image, slider = display_video_frame(source_path, 1, total_frames) + + return thumbs, video_image, slider, cut_button + + +def display_video_frame(filename: str, frame_num: int, total: int=0) -> Tuple[gr.Image, gr.Slider]: + global current_video_fps + + current_frame = get_video_frame(filename, frame_num) + if current_video_fps == 0: + current_video_fps = 1 + secs = (frame_num - 1) / current_video_fps + minutes = secs / 60 + secs = secs % 60 + hours = minutes / 60 + minutes = minutes % 60 + milliseconds = (secs - int(secs)) * 1000 + timeinfo = f"{int(hours):0>2}:{int(minutes):0>2}:{int(secs):0>2}.{int(milliseconds):0>3}" + if total > 0: + return gr.Image(value=util.convert_to_gradio(current_frame), interactive=True), gr.Slider(info=timeinfo, minimum=1, maximum=total, interactive=True) + return gr.Image(value=util.convert_to_gradio(current_frame), interactive=True), gr.Slider(info=timeinfo, interactive=True) + + +def on_face_selected(evt: gr.SelectData) -> None: + global selected_face_index + + if evt is not None: + selected_face_index = evt.index + +def on_frame_num_fst_changed(inputfiles: List[gr.Files], frame_num: int) -> Frame: + filename = inputfiles[0].name + video_image, _ = display_video_frame(filename, frame_num, 0) + return video_image + + +def on_cutfromframe_clicked(inputfiles: List[gr.Files], frame_num: int) -> List[Frame]: + global thumbs + + filename = inputfiles[0].name + SELECTION_FACES_DATA = extract_face_images(filename, (True, frame_num), 0.5) + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + return thumbs + + +def on_remove_clicked() -> List[Frame]: + global thumbs, images, selected_face_index + + if len(thumbs) > selected_face_index: + f = thumbs.pop(selected_face_index) + del f + f = images.pop(selected_face_index) + del f + return thumbs + +def on_clear_clicked() -> Tuple[List[Frame], None, None]: + global thumbs, images + + thumbs.clear() + images.clear() + return thumbs, None, None + + +def on_update_clicked() -> Optional[str]: + if len(images) < 1: + gr.Warning(f"No faces to create faceset from!") + return None + + imgnames = [] + for index,img in enumerate(images): + filename = os.path.join(roop.globals.output_path, f'{index}.png') + cv2.imwrite(filename, img) + imgnames.append(filename) + + finalzip = os.path.join(roop.globals.output_path, 'faceset.fsz') + util.zip(imgnames, finalzip) + return finalzip diff --git a/roop-unleashed/ui/tabs/faceswap_tab.py b/roop-unleashed/ui/tabs/faceswap_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..97ec1fab45ce21296014c9cc896c072494a582c2 --- /dev/null +++ b/roop-unleashed/ui/tabs/faceswap_tab.py @@ -0,0 +1,717 @@ +import os +import shutil +import pathlib +import gradio as gr +import roop.utilities as util +import roop.globals +import ui.globals +from roop.face_util import extract_face_images, create_blank_image +from roop.capturer import get_video_frame, get_video_frame_total, get_image_frame +from roop.ProcessEntry import ProcessEntry +from roop.ProcessOptions import ProcessOptions +from roop.FaceSet import FaceSet + +last_image = None + + +IS_INPUT = True +SELECTED_FACE_INDEX = 0 + +SELECTED_INPUT_FACE_INDEX = 0 +SELECTED_TARGET_FACE_INDEX = 0 + +input_faces = None +target_faces = None +face_selection = None +previewimage = None + +selected_preview_index = 0 + +is_processing = False + +list_files_process : list[ProcessEntry] = [] +no_face_choices = ["Use untouched original frame","Retry rotated", "Skip Frame", "Skip Frame if no similar face"] + +current_video_fps = 50 + +manual_masking = False + + +def faceswap_tab(): + global no_face_choices, previewimage + + with gr.Tab("๐ŸŽญ Face Swap"): + with gr.Row(variant='panel'): + with gr.Column(scale=2): + with gr.Row(): + with gr.Column(min_width=160): + input_faces = gr.Gallery(label="Input faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8) + with gr.Accordion(label="Advanced Masking", open=False): + chk_showmaskoffsets = gr.Checkbox(label="Show mask overlay in preview", value=False, interactive=True) + mask_top = gr.Slider(0, 1.0, value=0, label="Offset Face Top", step=0.01, interactive=True) + mask_bottom = gr.Slider(0, 1.0, value=0, label="Offset Face Bottom", step=0.01, interactive=True) + mask_left = gr.Slider(0, 1.0, value=0, label="Offset Face Left", step=0.01, interactive=True) + mask_right = gr.Slider(0, 1.0, value=0, label="Offset Face Right", step=0.01, interactive=True) + mask_erosion = gr.Slider(1.0, 3.0, value=1.0, label="Erosion Iterations", step=1.00, interactive=True) + mask_blur = gr.Slider(10.0, 50.0, value=20.0, label="Blur size", step=1.00, interactive=True) + bt_toggle_masking = gr.Button("Toggle manual masking", variant='secondary', size='sm') + selected_mask_engine = gr.Dropdown(["None", "Clip2Seg", "DFL XSeg"], value="None", label="Face masking engine") + clip_text = gr.Textbox(label="List of objects to mask and restore back on fake face", value="cup,hands,hair,banana", interactive=False) + bt_preview_mask = gr.Button("๐Ÿ‘ฅ Show Mask Preview", variant='secondary') + bt_remove_selected_input_face = gr.Button("โŒ Remove selected", size='sm') + bt_clear_input_faces = gr.Button("๐Ÿ’ฅ Clear all", variant='stop', size='sm') + with gr.Column(min_width=160): + target_faces = gr.Gallery(label="Target faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8) + bt_remove_selected_target_face = gr.Button("โŒ Remove selected", size='sm') + bt_add_local = gr.Button('Add local files from', size='sm') + local_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True) + with gr.Row(variant='panel'): + bt_srcfiles = gr.Files(label='Source File(s)', file_count="multiple", file_types=["image", ".fsz"], elem_id='filelist', height=233) + bt_destfiles = gr.Files(label='Target File(s)', file_count="multiple", file_types=["image", "video"], elem_id='filelist', height=233) + with gr.Row(variant='panel'): + gr.Markdown('') + forced_fps = gr.Slider(minimum=0, maximum=120, value=0, label="Video FPS", info='Overrides detected fps if not 0', step=1.0, interactive=True, container=True) + + with gr.Column(scale=2): + previewimage = gr.Image(label="Preview Image", height=576, interactive=False, visible=True) + maskimage = gr.ImageEditor(label="Manual mask Image", sources=["clipboard"], transforms="", type="numpy", + brush=gr.Brush(color_mode="fixed", colors=["rgba(255, 255, 255, 1"]), interactive=True, visible=False) + with gr.Row(variant='panel'): + fake_preview = gr.Checkbox(label="Face swap frames", value=False) + bt_refresh_preview = gr.Button("๐Ÿ”„ Refresh", variant='secondary', size='sm') + bt_use_face_from_preview = gr.Button("Use Face from this Frame", variant='primary', size='sm') + with gr.Row(): + preview_frame_num = gr.Slider(1, 1, value=1, label="Frame Number", info='0:00:00', step=1.0, interactive=True) + with gr.Row(): + text_frame_clip = gr.Markdown('Processing frame range [0 - 0]') + set_frame_start = gr.Button("โฌ… Set as Start", size='sm') + set_frame_end = gr.Button("โžก Set as End", size='sm') + with gr.Row(visible=False) as dynamic_face_selection: + with gr.Column(scale=2): + face_selection = gr.Gallery(label="Detected faces", allow_preview=False, preview=False, height=256, object_fit="cover", columns=8) + with gr.Column(): + bt_faceselect = gr.Button("โ˜‘ Use selected face", size='sm') + bt_cancelfaceselect = gr.Button("Done", size='sm') + with gr.Column(): + gr.Markdown(' ') + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + selected_face_detection = gr.Dropdown(["First found", "All female", "All male", "All faces", "Selected face"], value="First found", label="Specify face selection for swapping") + with gr.Column(scale=1): + ui.globals.ui_selected_enhancer = gr.Dropdown(["None", "Codeformer", "DMDNet", "GFPGAN", "GPEN", "Restoreformer++"], value="None", label="Select post-processing") + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + max_face_distance = gr.Slider(0.01, 1.0, value=0.65, label="Max Face Similarity Threshold", info="0.0 = identical 1.0 = no similarity") + with gr.Column(scale=1): + num_swap_steps = gr.Slider(1, 5, value=1, step=1.0, label="Number of swapping steps", info="More steps can increase likeness") + with gr.Column(scale=2): + ui.globals.ui_blend_ratio = gr.Slider(0.0, 1.0, value=0.65, label="Original/Enhanced image blend ratio", info="Only used with active post-processing") + + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + video_swapping_method = gr.Dropdown(["Extract Frames to media","In-Memory processing"], value="In-Memory processing", label="Select video processing method", interactive=True) + no_face_action = gr.Dropdown(choices=no_face_choices, value=no_face_choices[0], label="Action on no face detected", interactive=True) + vr_mode = gr.Checkbox(label="VR Mode", value=False) + with gr.Column(scale=1): + with gr.Group(): + autorotate = gr.Checkbox(label="Auto rotate horizontal Faces", value=True) + roop.globals.skip_audio = gr.Checkbox(label="Skip audio", value=False) + roop.globals.keep_frames = gr.Checkbox(label="Keep Frames (relevant only when extracting frames)", value=False) + roop.globals.wait_after_extraction = gr.Checkbox(label="Wait for user key press before creating video ", value=False) + + + + with gr.Row(variant='panel'): + with gr.Column(): + bt_start = gr.Button("โ–ถ Start", variant='primary') + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Column(): + bt_stop = gr.Button("โน Stop", variant='secondary', interactive=False) + with gr.Column(scale=2): + gr.Markdown(' ') + with gr.Row(variant='panel'): + with gr.Column(): + resultfiles = gr.Files(label='Processed File(s)', interactive=False) + with gr.Column(): + resultimage = gr.Image(type='filepath', label='Final Image', interactive=False ) + resultvideo = gr.Video(label='Final Video', interactive=False, visible=False) + + previewinputs = [preview_frame_num, bt_destfiles, fake_preview, ui.globals.ui_selected_enhancer, selected_face_detection, + max_face_distance, ui.globals.ui_blend_ratio, selected_mask_engine, clip_text, no_face_action, vr_mode, autorotate, maskimage, chk_showmaskoffsets, num_swap_steps] + previewoutputs = [previewimage, maskimage, preview_frame_num] + input_faces.select(on_select_input_face, None, None).then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + bt_remove_selected_input_face.click(fn=remove_selected_input_face, outputs=[input_faces]) + bt_srcfiles.change(fn=on_srcfile_changed, show_progress='full', inputs=bt_srcfiles, outputs=[dynamic_face_selection, face_selection, input_faces]) + + mask_top.release(fn=on_mask_top_changed, inputs=[mask_top], show_progress='hidden') + mask_bottom.release(fn=on_mask_bottom_changed, inputs=[mask_bottom], show_progress='hidden') + mask_left.release(fn=on_mask_left_changed, inputs=[mask_left], show_progress='hidden') + mask_right.release(fn=on_mask_right_changed, inputs=[mask_right], show_progress='hidden') + mask_erosion.release(fn=on_mask_erosion_changed, inputs=[mask_erosion], show_progress='hidden') + mask_blur.release(fn=on_mask_blur_changed, inputs=[mask_blur], show_progress='hidden') + selected_mask_engine.change(fn=on_mask_engine_changed, inputs=[selected_mask_engine], outputs=[clip_text], show_progress='hidden') + + + target_faces.select(on_select_target_face, None, None) + bt_remove_selected_target_face.click(fn=remove_selected_target_face, outputs=[target_faces]) + + forced_fps.change(fn=on_fps_changed, inputs=[forced_fps], show_progress='hidden') + bt_destfiles.change(fn=on_destfiles_changed, inputs=[bt_destfiles], outputs=[preview_frame_num, text_frame_clip], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden') + bt_destfiles.select(fn=on_destfiles_selected, outputs=[preview_frame_num, text_frame_clip, forced_fps], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden') + bt_destfiles.clear(fn=on_clear_destfiles, outputs=[target_faces, selected_face_detection]) + resultfiles.select(fn=on_resultfiles_selected, inputs=[resultfiles], outputs=[resultimage, resultvideo]) + + face_selection.select(on_select_face, None, None) + bt_faceselect.click(fn=on_selected_face, outputs=[input_faces, target_faces, selected_face_detection]) + bt_cancelfaceselect.click(fn=on_end_face_selection, outputs=[dynamic_face_selection, face_selection]) + + bt_clear_input_faces.click(fn=on_clear_input_faces, outputs=[input_faces]) + + + bt_add_local.click(fn=on_add_local_folder, inputs=[local_folder], outputs=[bt_destfiles]) + bt_preview_mask.click(fn=on_preview_mask, inputs=[preview_frame_num, bt_destfiles, clip_text, selected_mask_engine], outputs=[previewimage]) + + start_event = bt_start.click(fn=start_swap, + inputs=[ui.globals.ui_selected_enhancer, selected_face_detection, roop.globals.keep_frames, roop.globals.wait_after_extraction, + roop.globals.skip_audio, max_face_distance, ui.globals.ui_blend_ratio, selected_mask_engine, clip_text,video_swapping_method, no_face_action, vr_mode, autorotate, num_swap_steps, maskimage], + outputs=[bt_start, bt_stop, resultfiles], show_progress='full') + after_swap_event = start_event.then(fn=on_resultfiles_finished, inputs=[resultfiles], outputs=[resultimage, resultvideo]) + + bt_stop.click(fn=stop_swap, cancels=[start_event, after_swap_event], outputs=[bt_start, bt_stop], queue=False) + + bt_refresh_preview.click(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + bt_toggle_masking.click(fn=on_toggle_masking, inputs=[previewimage, maskimage], outputs=[previewimage, maskimage]) + fake_preview.change(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + preview_frame_num.release(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden', ) + bt_use_face_from_preview.click(fn=on_use_face_from_selected, show_progress='full', inputs=[bt_destfiles, preview_frame_num], outputs=[dynamic_face_selection, face_selection, target_faces, selected_face_detection]) + set_frame_start.click(fn=on_set_frame, inputs=[set_frame_start, preview_frame_num], outputs=[text_frame_clip]) + set_frame_end.click(fn=on_set_frame, inputs=[set_frame_end, preview_frame_num], outputs=[text_frame_clip]) + + + +def on_mask_top_changed(mask_offset): + set_mask_offset(0, mask_offset) + +def on_mask_bottom_changed(mask_offset): + set_mask_offset(1, mask_offset) + +def on_mask_left_changed(mask_offset): + set_mask_offset(2, mask_offset) + +def on_mask_right_changed(mask_offset): + set_mask_offset(3, mask_offset) + +def on_mask_erosion_changed(mask_offset): + set_mask_offset(4, mask_offset) +def on_mask_blur_changed(mask_offset): + set_mask_offset(5, mask_offset) + + +def set_mask_offset(index, mask_offset): + global SELECTED_INPUT_FACE_INDEX + + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + offs = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets + offs[index] = mask_offset + if offs[0] + offs[1] > 0.99: + offs[0] = 0.99 + offs[1] = 0.0 + if offs[2] + offs[3] > 0.99: + offs[2] = 0.99 + offs[3] = 0.0 + roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets = offs + +def on_mask_engine_changed(mask_engine): + if mask_engine == "Clip2Seg": + return gr.Textbox(interactive=True) + return gr.Textbox(interactive=False) + + + +def on_add_local_folder(folder): + files = util.get_local_files_from_folder(folder) + if files is None: + gr.Warning("Empty folder or folder not found!") + return files + + +def on_srcfile_changed(srcfiles, progress=gr.Progress()): + global SELECTION_FACES_DATA, IS_INPUT, input_faces, face_selection, last_image + + IS_INPUT = True + + if srcfiles is None or len(srcfiles) < 1: + return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + + thumbs = [] + for f in srcfiles: + source_path = f.name + if source_path.lower().endswith('fsz'): + progress(0, desc="Retrieving faces from Faceset File") + unzipfolder = os.path.join(os.environ["TEMP"], 'faceset') + if os.path.isdir(unzipfolder): + files = os.listdir(unzipfolder) + for file in files: + os.remove(os.path.join(unzipfolder, file)) + else: + os.makedirs(unzipfolder) + util.mkdir_with_umask(unzipfolder) + util.unzip(source_path, unzipfolder) + is_first = True + face_set = FaceSet() + for file in os.listdir(unzipfolder): + if file.endswith(".png"): + filename = os.path.join(unzipfolder,file) + progress(0, desc="Extracting faceset") + SELECTION_FACES_DATA = extract_face_images(filename, (False, 0)) + for f in SELECTION_FACES_DATA: + face = f[0] + face.mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(face) + if is_first: + image = util.convert_to_gradio(f[1]) + ui.globals.ui_input_thumbs.append(image) + is_first = False + face_set.ref_images.append(get_image_frame(filename)) + if len(face_set.faces) > 0: + if len(face_set.faces) > 1: + face_set.AverageEmbeddings() + roop.globals.INPUT_FACESETS.append(face_set) + + elif util.has_image_extension(source_path): + progress(0, desc="Retrieving faces from image") + roop.globals.source_path = source_path + SELECTION_FACES_DATA = extract_face_images(roop.globals.source_path, (False, 0)) + progress(0.5, desc="Retrieving faces from image") + for f in SELECTION_FACES_DATA: + face_set = FaceSet() + face = f[0] + face.mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(face) + image = util.convert_to_gradio(f[1]) + ui.globals.ui_input_thumbs.append(image) + roop.globals.INPUT_FACESETS.append(face_set) + + progress(1.0) + + # old style with selecting input faces commented out + # if len(thumbs) < 1: + # return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + # return gr.Column(visible=True), thumbs, gr.Gallery(visible=True) + + return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + + +def on_select_input_face(evt: gr.SelectData): + global SELECTED_INPUT_FACE_INDEX + + SELECTED_INPUT_FACE_INDEX = evt.index + + +def remove_selected_input_face(): + global SELECTED_INPUT_FACE_INDEX + + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + f = roop.globals.INPUT_FACESETS.pop(SELECTED_INPUT_FACE_INDEX) + del f + if len(ui.globals.ui_input_thumbs) > SELECTED_INPUT_FACE_INDEX: + f = ui.globals.ui_input_thumbs.pop(SELECTED_INPUT_FACE_INDEX) + del f + + return ui.globals.ui_input_thumbs + +def on_select_target_face(evt: gr.SelectData): + global SELECTED_TARGET_FACE_INDEX + + SELECTED_TARGET_FACE_INDEX = evt.index + +def remove_selected_target_face(): + if len(roop.globals.TARGET_FACES) > SELECTED_TARGET_FACE_INDEX: + f = roop.globals.TARGET_FACES.pop(SELECTED_TARGET_FACE_INDEX) + del f + if len(ui.globals.ui_target_thumbs) > SELECTED_TARGET_FACE_INDEX: + f = ui.globals.ui_target_thumbs.pop(SELECTED_TARGET_FACE_INDEX) + del f + return ui.globals.ui_target_thumbs + + + + + +def on_use_face_from_selected(files, frame_num): + global IS_INPUT, SELECTION_FACES_DATA + + IS_INPUT = False + thumbs = [] + + roop.globals.target_path = files[selected_preview_index].name + if util.is_image(roop.globals.target_path) and not roop.globals.target_path.lower().endswith(('gif')): + SELECTION_FACES_DATA = extract_face_images(roop.globals.target_path, (False, 0)) + if len(SELECTION_FACES_DATA) > 0: + for f in SELECTION_FACES_DATA: + image = util.convert_to_gradio(f[1]) + thumbs.append(image) + else: + gr.Info('No faces detected!') + roop.globals.target_path = None + + elif util.is_video(roop.globals.target_path) or roop.globals.target_path.lower().endswith(('gif')): + selected_frame = frame_num + SELECTION_FACES_DATA = extract_face_images(roop.globals.target_path, (True, selected_frame)) + if len(SELECTION_FACES_DATA) > 0: + for f in SELECTION_FACES_DATA: + image = util.convert_to_gradio(f[1]) + thumbs.append(image) + else: + gr.Info('No faces detected!') + roop.globals.target_path = None + + if len(thumbs) == 1: + roop.globals.TARGET_FACES.append(SELECTION_FACES_DATA[0][0]) + ui.globals.ui_target_thumbs.append(thumbs[0]) + return gr.Row(visible=False), None, ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face') + + return gr.Row(visible=True), thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True) + + + +def on_select_face(evt: gr.SelectData): # SelectData is a subclass of EventData + global SELECTED_FACE_INDEX + SELECTED_FACE_INDEX = evt.index + + +def on_selected_face(): + global IS_INPUT, SELECTED_FACE_INDEX, SELECTION_FACES_DATA + + fd = SELECTION_FACES_DATA[SELECTED_FACE_INDEX] + image = util.convert_to_gradio(fd[1]) + if IS_INPUT: + face_set = FaceSet() + fd[0].mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(fd[0]) + roop.globals.INPUT_FACESETS.append(face_set) + ui.globals.ui_input_thumbs.append(image) + return ui.globals.ui_input_thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True) + else: + roop.globals.TARGET_FACES.append(fd[0]) + ui.globals.ui_target_thumbs.append(image) + return gr.Gallery(visible=True), ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face') + +# bt_faceselect.click(fn=on_selected_face, outputs=[dynamic_face_selection, face_selection, input_faces, target_faces]) + +def on_end_face_selection(): + return gr.Column(visible=False), None + + +def on_preview_frame_changed(frame_num, files, fake_preview, enhancer, detection, face_distance, blend_ratio, + selected_mask_engine, clip_text, no_face_action, vr_mode, auto_rotate, maskimage, show_face_area, num_steps): + global SELECTED_INPUT_FACE_INDEX, manual_masking, current_video_fps + + from roop.core import live_swap, get_processing_plugins + + manual_masking = False + mask_offsets = (0,0,0,0) + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + if not hasattr(roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0], 'mask_offsets'): + roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets = mask_offsets + mask_offsets = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets + + timeinfo = '0:00:00' + if files is None or selected_preview_index >= len(files) or frame_num is None: + return None,None, gr.Slider(info=timeinfo) + + filename = files[selected_preview_index].name + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename, frame_num) + if current_video_fps == 0: + current_video_fps = 1 + secs = (frame_num - 1) / current_video_fps + minutes = secs / 60 + secs = secs % 60 + hours = minutes / 60 + minutes = minutes % 60 + milliseconds = (secs - int(secs)) * 1000 + timeinfo = f"{int(hours):0>2}:{int(minutes):0>2}:{int(secs):0>2}.{int(milliseconds):0>3}" + else: + current_frame = get_image_frame(filename) + if current_frame is None: + return None, None, gr.Slider(info=timeinfo) + + layers = None + if maskimage is not None: + layers = maskimage["layers"] + + if not fake_preview or len(roop.globals.INPUT_FACESETS) < 1: + return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo) + + roop.globals.face_swap_mode = translate_swap_mode(detection) + roop.globals.selected_enhancer = enhancer + roop.globals.distance_threshold = face_distance + roop.globals.blend_ratio = blend_ratio + roop.globals.no_face_action = index_of_no_face_action(no_face_action) + roop.globals.vr_mode = vr_mode + roop.globals.autorotate_faces = auto_rotate + + mask_engine = map_mask_engine(selected_mask_engine, clip_text) + + roop.globals.execution_threads = roop.globals.CFG.max_threads + mask = layers[0] if layers is not None else None + face_index = SELECTED_INPUT_FACE_INDEX + if len(roop.globals.INPUT_FACESETS) <= face_index: + face_index = 0 + + options = ProcessOptions(get_processing_plugins(mask_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, + roop.globals.face_swap_mode, face_index, clip_text, maskimage, num_steps, show_face_area) + + current_frame = live_swap(current_frame, options) + if current_frame is None: + return gr.Image(visible=True), None, gr.Slider(info=timeinfo) + return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo) + +def map_mask_engine(selected_mask_engine, clip_text): + if selected_mask_engine == "Clip2Seg": + mask_engine = "mask_clip2seg" + if clip_text is None or len(clip_text) < 1: + mask_engine = None + elif selected_mask_engine == "DFL XSeg": + mask_engine = "mask_xseg" + else: + mask_engine = None + return mask_engine + + + +def on_toggle_masking(previewimage, mask): + global manual_masking + + manual_masking = not manual_masking + if manual_masking: + layers = mask["layers"] + if len(layers) == 1: + layers = [create_blank_image(previewimage.shape[1],previewimage.shape[0])] + return gr.Image(visible=False), gr.ImageEditor(value={"background": previewimage, "layers": layers, "composite": None}, visible=True) + return gr.Image(visible=True), gr.ImageEditor(visible=False) + +def gen_processing_text(start, end): + return f'Processing frame range [{start} - {end}]' + +def on_set_frame(sender:str, frame_num): + global selected_preview_index, list_files_process + + idx = selected_preview_index + if list_files_process[idx].endframe == 0: + return gen_processing_text(0,0) + + start = list_files_process[idx].startframe + end = list_files_process[idx].endframe + if sender.lower().endswith('start'): + list_files_process[idx].startframe = min(frame_num, end) + else: + list_files_process[idx].endframe = max(frame_num, start) + + return gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe) + + + +def on_preview_mask(frame_num, files, clip_text, mask_engine): + from roop.core import live_swap, get_processing_plugins + global is_processing + + if is_processing or files is None or selected_preview_index >= len(files) or clip_text is None or frame_num is None: + return None + + filename = files[selected_preview_index].name + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename, frame_num + ) + else: + current_frame = get_image_frame(filename) + if current_frame is None or mask_engine is None: + return None + if mask_engine == "Clip2Seg": + mask_engine = "mask_clip2seg" + if clip_text is None or len(clip_text) < 1: + mask_engine = None + elif mask_engine == "DFL XSeg": + mask_engine = "mask_xseg" + options = ProcessOptions(get_processing_plugins(mask_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, + "all", 0, clip_text, None, 0, False, True) + + current_frame = live_swap(current_frame, options) + return util.convert_to_gradio(current_frame) + + + +def on_clear_input_faces(): + ui.globals.ui_input_thumbs.clear() + roop.globals.INPUT_FACESETS.clear() + return ui.globals.ui_input_thumbs + +def on_clear_destfiles(): + roop.globals.TARGET_FACES.clear() + ui.globals.ui_target_thumbs.clear() + return ui.globals.ui_target_thumbs, gr.Dropdown(value="First found") + + +def index_of_no_face_action(dropdown_text): + global no_face_choices + + return no_face_choices.index(dropdown_text) + +def translate_swap_mode(dropdown_text): + if dropdown_text == "Selected face": + return "selected" + elif dropdown_text == "First found": + return "first" + elif dropdown_text == "All female": + return "all_female" + elif dropdown_text == "All male": + return "all_male" + + return "all" + + + +def start_swap( enhancer, detection, keep_frames, wait_after_extraction, skip_audio, face_distance, blend_ratio, + selected_mask_engine, clip_text, processing_method, no_face_action, vr_mode, autorotate, num_swap_steps, imagemask, progress=gr.Progress()): + from ui.main import prepare_environment + from roop.core import batch_process_regular + global is_processing, list_files_process + + if list_files_process is None or len(list_files_process) <= 0: + return gr.Button(variant="primary"), None, None + + if roop.globals.CFG.clear_output: + shutil.rmtree(roop.globals.output_path) + + if not util.is_installed("ffmpeg"): + msg = "ffmpeg is not installed! No video processing possible." + gr.Warning(msg) + + prepare_environment() + + roop.globals.selected_enhancer = enhancer + roop.globals.target_path = None + roop.globals.distance_threshold = face_distance + roop.globals.blend_ratio = blend_ratio + roop.globals.keep_frames = keep_frames + roop.globals.wait_after_extraction = wait_after_extraction + roop.globals.skip_audio = skip_audio + roop.globals.face_swap_mode = translate_swap_mode(detection) + roop.globals.no_face_action = index_of_no_face_action(no_face_action) + roop.globals.vr_mode = vr_mode + roop.globals.autorotate_faces = autorotate + mask_engine = map_mask_engine(selected_mask_engine, clip_text) + + if roop.globals.face_swap_mode == 'selected': + if len(roop.globals.TARGET_FACES) < 1: + gr.Error('No Target Face selected!') + return gr.Button(variant="primary"), None, None + + is_processing = True + yield gr.Button(variant="secondary", interactive=False), gr.Button(variant="primary", interactive=True), None + roop.globals.execution_threads = roop.globals.CFG.max_threads + roop.globals.video_encoder = roop.globals.CFG.output_video_codec + roop.globals.video_quality = roop.globals.CFG.video_quality + roop.globals.max_memory = roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None + + batch_process_regular(list_files_process, mask_engine, clip_text, processing_method == "In-Memory processing", imagemask, num_swap_steps, progress, SELECTED_INPUT_FACE_INDEX) + is_processing = False + outdir = pathlib.Path(roop.globals.output_path) + outfiles = [str(item) for item in outdir.rglob("*") if item.is_file()] + if len(outfiles) > 0: + yield gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),gr.Files(value=outfiles) + else: + yield gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),None + + +def stop_swap(): + roop.globals.processing = False + gr.Info('Aborting processing - please wait for the remaining threads to be stopped') + return gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),None + + +def on_fps_changed(fps): + global selected_preview_index, list_files_process + + if len(list_files_process) < 1 or list_files_process[selected_preview_index].endframe < 1: + return + list_files_process[selected_preview_index].fps = fps + + +def on_destfiles_changed(destfiles): + global selected_preview_index, list_files_process, current_video_fps + + if destfiles is None or len(destfiles) < 1: + list_files_process.clear() + return gr.Slider(value=1, maximum=1, info='0:00:00'), '' + + for f in destfiles: + list_files_process.append(ProcessEntry(f.name, 0,0, 0)) + + selected_preview_index = 0 + idx = selected_preview_index + + filename = list_files_process[idx].filename + + if util.is_video(filename) or filename.lower().endswith('gif'): + total_frames = get_video_frame_total(filename) + current_video_fps = util.detect_fps(filename) + else: + total_frames = 1 + list_files_process[idx].endframe = total_frames + if total_frames > 1: + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe) + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), '' + + + + +def on_destfiles_selected(evt: gr.SelectData): + global selected_preview_index, list_files_process, current_video_fps + + if evt is not None: + selected_preview_index = evt.index + idx = selected_preview_index + filename = list_files_process[idx].filename + fps = list_files_process[idx].fps + if util.is_video(filename) or filename.lower().endswith('gif'): + total_frames = get_video_frame_total(filename) + current_video_fps = util.detect_fps(filename) + if list_files_process[idx].endframe == 0: + list_files_process[idx].endframe = total_frames + else: + total_frames = 1 + + if total_frames > 1: + return gr.Slider(value=list_files_process[idx].startframe, maximum=total_frames, info='0:00:00'), gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe), fps + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(0,0), fps + + + +def on_resultfiles_selected(evt: gr.SelectData, files): + selected_index = evt.index + filename = files[selected_index].name + return display_output(filename) + +def on_resultfiles_finished(files): + selected_index = 0 + if files is None or len(files) < 1: + return None, None + + filename = files[selected_index].name + return display_output(filename) + + +def display_output(filename): + if util.is_video(filename) and roop.globals.CFG.output_show_video: + return gr.Image(visible=False), gr.Video(visible=True, value=filename) + else: + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename) + else: + current_frame = get_image_frame(filename) + return gr.Image(visible=True, value=util.convert_to_gradio(current_frame)), gr.Video(visible=False) diff --git a/roop-unleashed/ui/tabs/livecam_tab.py b/roop-unleashed/ui/tabs/livecam_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b5a228f8a324291be072790d22828350109b12 --- /dev/null +++ b/roop-unleashed/ui/tabs/livecam_tab.py @@ -0,0 +1,54 @@ +import gradio as gr +import roop.globals +import ui.globals + + +camera_frame = None + +def livecam_tab(): + with gr.Tab("๐ŸŽฅ Live Cam"): + with gr.Row(variant='panel'): + gr.Markdown(""" + This feature will allow you to use your physical webcam and apply the selected faces to the stream. + You can also forward the stream to a virtual camera, which can be used in video calls or streaming software.
+ Supported are: v4l2loopback (linux), OBS Virtual Camera (macOS/Windows) and unitycapture (Windows).
+ **Please note:** to change the face or any other settings you need to stop and restart a running live cam. + """) + + with gr.Row(variant='panel'): + with gr.Column(): + bt_start = gr.Button("โ–ถ Start", variant='primary') + with gr.Column(): + bt_stop = gr.Button("โน Stop", variant='secondary', interactive=False) + with gr.Column(): + camera_num = gr.Slider(0, 8, value=0, label="Camera Number", step=1.0, interactive=True) + cb_obs = gr.Checkbox(label="Forward stream to virtual camera", interactive=True) + with gr.Column(): + dd_reso = gr.Dropdown(choices=["640x480","1280x720", "1920x1080"], value="1280x720", label="Fake Camera Resolution", interactive=True) + + with gr.Row(): + fake_cam_image = gr.Image(label='Fake Camera Output', interactive=False) + + start_event = bt_start.click(fn=start_cam, inputs=[cb_obs, camera_num, dd_reso, ui.globals.ui_selected_enhancer, ui.globals.ui_blend_ratio],outputs=[bt_start, bt_stop,fake_cam_image]) + bt_stop.click(fn=stop_swap, cancels=[start_event], outputs=[bt_start, bt_stop], queue=False) + + +def start_cam(stream_to_obs, cam, reso, enhancer, blend_ratio): + from roop.virtualcam import start_virtual_cam + from roop.utilities import convert_to_gradio + + start_virtual_cam(stream_to_obs, cam, reso) + roop.globals.selected_enhancer = enhancer + roop.globals.blend_ratio = blend_ratio + while True: + yield gr.Button(interactive=False), gr.Button(interactive=True), convert_to_gradio(ui.globals.ui_camera_frame) + + +def stop_swap(): + from roop.virtualcam import stop_virtual_cam + stop_virtual_cam() + return gr.Button(interactive=True), gr.Button(interactive=False) + + + + diff --git a/roop-unleashed/ui/tabs/settings_tab.py b/roop-unleashed/ui/tabs/settings_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b34e91ac3090946afa7ec1bd2dd6ef29767f54 --- /dev/null +++ b/roop-unleashed/ui/tabs/settings_tab.py @@ -0,0 +1,129 @@ +import shutil +import os +import gradio as gr +import roop.globals +import ui.globals + +available_themes = ["Default", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", "gstaff/xkcd", "freddyaboulton/dracula_revamped", "ysharma/steampunk"] +image_formats = ['jpg','png', 'webp'] +video_formats = ['avi','mkv', 'mp4', 'webm'] +video_codecs = ['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'] +providerlist = None + +settings_controls = [] + +def settings_tab(): + from roop.core import suggest_execution_providers + global providerlist + + providerlist = suggest_execution_providers() + with gr.Tab("โš™ Settings"): + with gr.Row(): + with gr.Column(): + themes = gr.Dropdown(available_themes, label="Theme", info="Change needs complete restart", value=roop.globals.CFG.selected_theme) + with gr.Column(): + settings_controls.append(gr.Checkbox(label="Public Server", value=roop.globals.CFG.server_share, elem_id='server_share', interactive=True)) + settings_controls.append(gr.Checkbox(label='Clear output folder before each run', value=roop.globals.CFG.clear_output, elem_id='clear_output', interactive=True)) + output_template = gr.Textbox(label="Filename Output Template", info="(file extension is added automatically)", lines=1, placeholder='{file}_{time}', value=roop.globals.CFG.output_template) + with gr.Column(): + input_server_name = gr.Textbox(label="Server Name", lines=1, info="Leave blank to run locally", value=roop.globals.CFG.server_name) + with gr.Column(): + input_server_port = gr.Number(label="Server Port", precision=0, info="Leave at 0 to use default", value=roop.globals.CFG.server_port) + with gr.Row(): + with gr.Column(): + settings_controls.append(gr.Dropdown(providerlist, label="Provider", value=roop.globals.CFG.provider, elem_id='provider', interactive=True)) + chk_det_size = gr.Checkbox(label="Use default Det-Size", value=True, elem_id='default_det_size', interactive=True) + settings_controls.append(gr.Checkbox(label="Force CPU for Face Analyser", value=roop.globals.CFG.force_cpu, elem_id='force_cpu', interactive=True)) + max_threads = gr.Slider(1, 32, value=roop.globals.CFG.max_threads, label="Max. Number of Threads", info='default: 3', step=1.0, interactive=True) + with gr.Column(): + memory_limit = gr.Slider(0, 128, value=roop.globals.CFG.memory_limit, label="Max. Memory to use (Gb)", info='0 meaning no limit', step=1.0, interactive=True) + settings_controls.append(gr.Dropdown(image_formats, label="Image Output Format", info='default: png', value=roop.globals.CFG.output_image_format, elem_id='output_image_format', interactive=True)) + with gr.Column(): + settings_controls.append(gr.Dropdown(video_codecs, label="Video Codec", info='default: libx264', value=roop.globals.CFG.output_video_codec, elem_id='output_video_codec', interactive=True)) + settings_controls.append(gr.Dropdown(video_formats, label="Video Output Format", info='default: mp4', value=roop.globals.CFG.output_video_format, elem_id='output_video_format', interactive=True)) + video_quality = gr.Slider(0, 100, value=roop.globals.CFG.video_quality, label="Video Quality (crf)", info='default: 14', step=1.0, interactive=True) + with gr.Column(): + with gr.Group(): + settings_controls.append(gr.Checkbox(label='Use OS temp folder', value=roop.globals.CFG.use_os_temp_folder, elem_id='use_os_temp_folder', interactive=True)) + settings_controls.append(gr.Checkbox(label='Show video in browser (re-encodes output)', value=roop.globals.CFG.output_show_video, elem_id='output_show_video', interactive=True)) + button_apply_restart = gr.Button("Restart Server", variant='primary') + button_clean_temp = gr.Button("Clean temp folder") + button_apply_settings = gr.Button("Apply Settings") + + chk_det_size.select(fn=on_option_changed) + + # Settings + for s in settings_controls: + s.select(fn=on_settings_changed) + max_threads.input(fn=lambda a,b='max_threads':on_settings_changed_misc(a,b), inputs=[max_threads]) + memory_limit.input(fn=lambda a,b='memory_limit':on_settings_changed_misc(a,b), inputs=[memory_limit]) + video_quality.input(fn=lambda a,b='video_quality':on_settings_changed_misc(a,b), inputs=[video_quality]) + + # button_clean_temp.click(fn=clean_temp, outputs=[bt_srcfiles, input_faces, target_faces, bt_destfiles]) + button_clean_temp.click(fn=clean_temp) + button_apply_settings.click(apply_settings, inputs=[themes, input_server_name, input_server_port, output_template]) + button_apply_restart.click(restart) + + +def on_option_changed(evt: gr.SelectData): + attribname = evt.target.elem_id + if isinstance(evt.target, gr.Checkbox): + if hasattr(roop.globals, attribname): + setattr(roop.globals, attribname, evt.selected) + return + elif isinstance(evt.target, gr.Dropdown): + if hasattr(roop.globals, attribname): + setattr(roop.globals, attribname, evt.value) + return + raise gr.Error(f'Unhandled Setting for {evt.target}') + + +def on_settings_changed_misc(new_val, attribname): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, new_val) + else: + print("Didn't find attrib!") + + + +def on_settings_changed(evt: gr.SelectData): + attribname = evt.target.elem_id + if isinstance(evt.target, gr.Checkbox): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, evt.selected) + return + elif isinstance(evt.target, gr.Dropdown): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, evt.value) + return + + raise gr.Error(f'Unhandled Setting for {evt.target}') + +def clean_temp(): + from ui.main import prepare_environment + + if not roop.globals.CFG.use_os_temp_folder: + shutil.rmtree(os.environ["TEMP"]) + prepare_environment() + + ui.globals.ui_input_thumbs.clear() + roop.globals.INPUT_FACESETS.clear() + roop.globals.TARGET_FACES.clear() + ui.globals.ui_target_thumbs = [] + gr.Info('Temp Files removed') + return None,None,None,None + + +def apply_settings(themes, input_server_name, input_server_port, output_template): + from ui.main import show_msg + + roop.globals.CFG.selected_theme = themes + roop.globals.CFG.server_name = input_server_name + roop.globals.CFG.server_port = input_server_port + roop.globals.CFG.output_template = output_template + roop.globals.CFG.save() + show_msg('Settings saved') + + +def restart(): + ui.globals.ui_restart_server = True diff --git a/roop/FaceSet.py b/roop/FaceSet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e426219fe4265290883a026fbde2d0513d5d554 --- /dev/null +++ b/roop/FaceSet.py @@ -0,0 +1,20 @@ +import numpy as np + +class FaceSet: + faces = [] + ref_images = [] + embedding_average = 'None' + embeddings_backup = None + + def __init__(self): + self.faces = [] + self.ref_images = [] + self.embeddings_backup = None + + def AverageEmbeddings(self): + if len(self.faces) > 1 and self.embeddings_backup is None: + self.embeddings_backup = self.faces[0]['embedding'] + embeddings = [face.embedding for face in self.faces] + + self.faces[0]['embedding'] = np.mean(embeddings, axis=0) + # try median too? diff --git a/roop/ProcessEntry.py b/roop/ProcessEntry.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd53239463a14769954a10f1371d332bd88e05d --- /dev/null +++ b/roop/ProcessEntry.py @@ -0,0 +1,7 @@ +class ProcessEntry: + def __init__(self, filename: str, start: int, end: int, fps: float): + self.filename = filename + self.finalname = None + self.startframe = start + self.endframe = end + self.fps = fps \ No newline at end of file diff --git a/roop/ProcessMgr.py b/roop/ProcessMgr.py new file mode 100644 index 0000000000000000000000000000000000000000..285089389f1ec71c0b64a18098eb9657cb3fabb1 --- /dev/null +++ b/roop/ProcessMgr.py @@ -0,0 +1,701 @@ +import os +import cv2 +import numpy as np +import psutil + +from enum import Enum +from roop.ProcessOptions import ProcessOptions + +from roop.face_util import get_first_face, get_all_faces, rotate_image_180, rotate_anticlockwise, rotate_clockwise, clamp_cut_values +from roop.utilities import compute_cosine_distance, get_device, str_to_class +import roop.vr_util as vr + +from typing import Any, List, Callable +from roop.typing import Frame, Face +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Thread, Lock +from queue import Queue +from tqdm import tqdm +from roop.ffmpeg_writer import FFMPEG_VideoWriter +import roop.globals + +# Poor man's enum to be able to compare to int +class eNoFaceAction(): + USE_ORIGINAL_FRAME = 0 + RETRY_ROTATED = 1 + SKIP_FRAME = 2 + SKIP_FRAME_IF_DISSIMILAR = 3 + + + +def create_queue(temp_frame_paths: List[str]) -> Queue[str]: + queue: Queue[str] = Queue() + for frame_path in temp_frame_paths: + queue.put(frame_path) + return queue + + +def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]: + queues = [] + for _ in range(queue_per_future): + if not queue.empty(): + queues.append(queue.get()) + return queues + + +class ProcessMgr(): + input_face_datas = [] + target_face_datas = [] + + imagemask = None + + processors = [] + options : ProcessOptions = None + + num_threads = 1 + current_index = 0 + processing_threads = 1 + buffer_wait_time = 0.1 + + lock = Lock() + + frames_queue = None + processed_queue = None + + videowriter= None + + progress_gradio = None + total_frames = 0 + + + + + plugins = { + 'faceswap' : 'FaceSwapInsightFace', + 'mask_clip2seg' : 'Mask_Clip2Seg', + 'mask_xseg' : 'Mask_XSeg', + 'codeformer' : 'Enhance_CodeFormer', + 'gfpgan' : 'Enhance_GFPGAN', + 'dmdnet' : 'Enhance_DMDNet', + 'gpen' : 'Enhance_GPEN', + 'restoreformer++' : 'Enhance_RestoreFormerPPlus', + 'colorizer' : 'Frame_Colorizer', + 'filter_generic' : 'Frame_Filter', + 'removebg' : 'Frame_Masking', + 'upscale' : 'Frame_Upscale' + } + + def __init__(self, progress): + if progress is not None: + self.progress_gradio = progress + + def reuseOldProcessor(self, name:str): + for p in self.processors: + if p.processorname == name: + return p + + return None + + + def initialize(self, input_faces, target_faces, options): + self.input_face_datas = input_faces + self.target_face_datas = target_faces + self.options = options + devicename = get_device() + + roop.globals.g_desired_face_analysis=["landmark_3d_68", "landmark_2d_106","detection","recognition"] + if options.swap_mode == "all_female" or options.swap_mode == "all_male": + roop.globals.g_desired_face_analysis.append("genderage") + + for p in self.processors: + newp = next((x for x in options.processors.keys() if x == p.processorname), None) + if newp is None: + p.Release() + del p + + newprocessors = [] + for key, extoption in options.processors.items(): + p = self.reuseOldProcessor(key) + if p is None: + classname = self.plugins[key] + module = 'roop.processors.' + classname + p = str_to_class(module, classname) + if p is not None: + extoption.update({"devicename": devicename}) + p.Initialize(extoption) + newprocessors.append(p) + else: + print(f"Not using {module}") + self.processors = newprocessors + + + + if isinstance(self.options.imagemask, dict) and self.options.imagemask.get("layers") and len(self.options.imagemask["layers"]) > 0: + self.options.imagemask = self.options.imagemask.get("layers")[0] + # Get rid of alpha + self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_RGBA2GRAY) + if np.any(self.options.imagemask): + mo = self.input_face_datas[0].faces[0].mask_offsets + self.options.imagemask = self.blur_area(self.options.imagemask, mo[4], mo[5]) + self.options.imagemask = self.options.imagemask.astype(np.float32) / 255 + self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_GRAY2RGB) + else: + self.options.imagemask = None + + self.options.frame_processing = False + for p in self.processors: + if p.type.startswith("frame_"): + self.options.frame_processing = True + + + + + + + def run_batch(self, source_files, target_files, threads:int = 1): + progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' + self.total_frames = len(source_files) + self.num_threads = threads + with tqdm(total=self.total_frames, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + with ThreadPoolExecutor(max_workers=threads) as executor: + futures = [] + queue = create_queue(source_files) + queue_per_future = max(len(source_files) // threads, 1) + while not queue.empty(): + future = executor.submit(self.process_frames, source_files, target_files, pick_queue(queue, queue_per_future), lambda: self.update_progress(progress)) + futures.append(future) + for future in as_completed(futures): + future.result() + + + def process_frames(self, source_files: List[str], target_files: List[str], current_files, update: Callable[[], None]) -> None: + for f in current_files: + if not roop.globals.processing: + return + + # Decode the byte array into an OpenCV image + temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR) + if temp_frame is not None: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(temp_frame) + resimg = frame + else: + resimg = self.process_frame(temp_frame) + if resimg is not None: + i = source_files.index(f) + cv2.imwrite(target_files[i], resimg) + if update: + update() + + + + def read_frames_thread(self, cap, frame_start, frame_end, num_threads): + num_frame = 0 + total_num = frame_end - frame_start + if frame_start > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES,frame_start) + + while True and roop.globals.processing: + ret, frame = cap.read() + if not ret: + break + + self.frames_queue[num_frame % num_threads].put(frame, block=True) + num_frame += 1 + if num_frame == total_num: + break + + for i in range(num_threads): + self.frames_queue[i].put(None) + + + + def process_videoframes(self, threadindex, progress) -> None: + while True: + frame = self.frames_queue[threadindex].get() + if frame is None: + self.processing_threads -= 1 + self.processed_queue[threadindex].put((False, None)) + return + else: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(frame) + resimg = frame + else: + resimg = self.process_frame(frame) + self.processed_queue[threadindex].put((True, resimg)) + del frame + progress() + + + def write_frames_thread(self): + nextindex = 0 + num_producers = self.num_threads + + while True: + process, frame = self.processed_queue[nextindex % self.num_threads].get() + nextindex += 1 + if frame is not None: + self.videowriter.write_frame(frame) + del frame + elif process == False: + num_producers -= 1 + if num_producers < 1: + return + + + + def run_batch_inmem(self, source_video, target_video, frame_start, frame_end, fps, threads:int = 1, skip_audio=False): + cap = cv2.VideoCapture(source_video) + # frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_count = (frame_end - frame_start) + 1 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + processed_resolution = None + for p in self.processors: + if hasattr(p, 'getProcessedResolution'): + processed_resolution = p.getProcessedResolution(width, height) + print(f"Processed resolution: {processed_resolution}") + if processed_resolution is not None: + width = processed_resolution[0] + height = processed_resolution[1] + + + self.total_frames = frame_count + self.num_threads = threads + + self.processing_threads = self.num_threads + self.frames_queue = [] + self.processed_queue = [] + for _ in range(threads): + self.frames_queue.append(Queue(1)) + self.processed_queue.append(Queue(1)) + + self.videowriter = FFMPEG_VideoWriter(target_video, (width, height), fps, codec=roop.globals.video_encoder, crf=roop.globals.video_quality, audiofile=None) + + readthread = Thread(target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads)) + readthread.start() + + writethread = Thread(target=self.write_frames_thread) + writethread.start() + + progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' + with tqdm(total=self.total_frames, desc='Processing', unit='frames', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + with ThreadPoolExecutor(thread_name_prefix='swap_proc', max_workers=self.num_threads) as executor: + futures = [] + + for threadindex in range(threads): + future = executor.submit(self.process_videoframes, threadindex, lambda: self.update_progress(progress)) + futures.append(future) + + for future in as_completed(futures): + future.result() + # wait for the task to complete + readthread.join() + writethread.join() + cap.release() + self.videowriter.close() + self.frames_queue.clear() + self.processed_queue.clear() + + + + + def update_progress(self, progress: Any = None) -> None: + process = psutil.Process(os.getpid()) + memory_usage = process.memory_info().rss / 1024 / 1024 / 1024 + progress.set_postfix({ + 'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB', + 'execution_threads': self.num_threads + }) + progress.update(1) + if self.progress_gradio is not None: + self.progress_gradio((progress.n, self.total_frames), desc='Processing', total=self.total_frames, unit='frames') + + +# https://github.com/deepinsight/insightface#third-party-re-implementation-of-arcface +# https://github.com/deepinsight/insightface/blob/master/alignment/coordinate_reg/image_infer.py +# https://github.com/deepinsight/insightface/issues/1350 +# https://github.com/linghu8812/tensorrt_inference + + + def process_frame(self, frame:Frame): + if len(self.input_face_datas) < 1 and not self.options.show_face_masking: + return frame + temp_frame = frame.copy() + num_swapped, temp_frame = self.swap_faces(frame, temp_frame) + if num_swapped > 0: + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME_IF_DISSIMILAR: + if len(self.input_face_datas) > num_swapped: + return None + return temp_frame + if roop.globals.no_face_action == eNoFaceAction.USE_ORIGINAL_FRAME: + return frame + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME: + #This only works with in-mem processing, as it simply skips the frame. + #For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg. + #If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is????? + #alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work? + return None + else: + return self.retry_rotated(frame) + + def retry_rotated(self, frame): + copyframe = frame.copy() + copyframe = rotate_clockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_anticlockwise(temp_frame) + + copyframe = frame.copy() + copyframe = rotate_anticlockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_clockwise(temp_frame) + del copyframe + return frame + + + + def swap_faces(self, frame, temp_frame): + num_faces_found = 0 + + if self.options.swap_mode == "first": + face = get_first_face(frame) + + if face is None: + return num_faces_found, frame + + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + else: + faces = get_all_faces(frame) + if faces is None: + return num_faces_found, frame + + if self.options.swap_mode == "all": + for face in faces: + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + del face + + elif self.options.swap_mode == "selected": + num_targetfaces = len(self.target_face_datas) + use_index = num_targetfaces == 1 + for i,tf in enumerate(self.target_face_datas): + for face in faces: + if compute_cosine_distance(tf.embedding, face.embedding) <= self.options.face_distance_threshold: + if i < len(self.input_face_datas): + if use_index: + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + else: + temp_frame = self.process_face(i, face, temp_frame) + num_faces_found += 1 + del face + if not roop.globals.vr_mode and num_faces_found == num_targetfaces: + break + elif self.options.swap_mode == "all_female" or self.options.swap_mode == "all_male": + gender = 'F' if self.options.swap_mode == "all_female" else 'M' + for face in faces: + if face.sex == gender: + num_faces_found += 1 + temp_frame = self.process_face(self.options.selected_index, face, temp_frame) + del face + + if roop.globals.vr_mode and num_faces_found % 2 > 0: + # stereo image, there has to be an even number of faces + num_faces_found = 0 + return num_faces_found, frame + if num_faces_found == 0: + return num_faces_found, frame + + #maskprocessor = next((x for x in self.processors if x.type == 'mask'), None) + + if self.options.imagemask is not None and self.options.imagemask.shape == frame.shape: + temp_frame = self.simple_blend_with_mask(temp_frame, frame, self.options.imagemask) + return num_faces_found, temp_frame + + + def rotation_action(self, original_face:Face, frame:Frame): + (height, width) = frame.shape[:2] + + bounding_box_width = original_face.bbox[2] - original_face.bbox[0] + bounding_box_height = original_face.bbox[3] - original_face.bbox[1] + horizontal_face = bounding_box_width > bounding_box_height + + center_x = width // 2.0 + start_x = original_face.bbox[0] + end_x = original_face.bbox[2] + bbox_center_x = start_x + (bounding_box_width // 2.0) + + # need to leverage the array of landmarks as decribed here: + # https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg + # basically, we should be able to check for the relative position of eyes and nose + # then use that to determine which way the face is actually facing when in a horizontal position + # and use that to determine the correct rotation_action + + forehead_x = original_face.landmark_2d_106[72][0] + chin_x = original_face.landmark_2d_106[0][0] + + if horizontal_face: + if chin_x < forehead_x: + # this is someone lying down with their face like this (: + return "rotate_anticlockwise" + elif forehead_x < chin_x: + # this is someone lying down with their face like this :) + return "rotate_clockwise" + if bbox_center_x >= center_x: + # this is someone lying down with their face in the right hand side of the frame + return "rotate_anticlockwise" + if bbox_center_x < center_x: + # this is someone lying down with their face in the left hand side of the frame + return "rotate_clockwise" + + return None + + + def auto_rotate_frame(self, original_face, frame:Frame): + target_face = original_face + original_frame = frame + + rotation_action = self.rotation_action(original_face, frame) + + if rotation_action == "rotate_anticlockwise": + #face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame + frame = rotate_anticlockwise(frame) + elif rotation_action == "rotate_clockwise": + #face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame + frame = rotate_clockwise(frame) + + return target_face, frame, rotation_action + + + def auto_unrotate_frame(self, frame:Frame, rotation_action): + if rotation_action == "rotate_anticlockwise": + return rotate_clockwise(frame) + elif rotation_action == "rotate_clockwise": + return rotate_anticlockwise(frame) + + return frame + + + + def process_face(self,face_index, target_face:Face, frame:Frame): + from roop.face_util import align_crop + + enhanced_frame = None + if(len(self.input_face_datas) > 0): + inputface = self.input_face_datas[face_index].faces[0] + else: + inputface = None + + rotation_action = None + if roop.globals.autorotate_faces: + # check for sideways rotation of face + rotation_action = self.rotation_action(target_face, frame) + if rotation_action is not None: + (startX, startY, endX, endY) = target_face["bbox"].astype("int") + width = endX - startX + height = endY - startY + offs = int(max(width,height) * 0.25) + rotcutframe,startX, startY, endX, endY = self.cutout(frame, startX - offs, startY - offs, endX + offs, endY + offs) + if rotation_action == "rotate_anticlockwise": + rotcutframe = rotate_anticlockwise(rotcutframe) + elif rotation_action == "rotate_clockwise": + rotcutframe = rotate_clockwise(rotcutframe) + # rotate image and re-detect face to correct wonky landmarks + rotface = get_first_face(rotcutframe) + if rotface is None: + rotation_action = None + else: + saved_frame = frame.copy() + frame = rotcutframe + target_face = rotface + + + + # if roop.globals.vr_mode: + # bbox = target_face.bbox + # [orig_width, orig_height, _] = frame.shape + + # # Convert bounding box to ints + # x1, y1, x2, y2 = map(int, bbox) + + # # Determine the center of the bounding box + # x_center = (x1 + x2) / 2 + # y_center = (y1 + y2) / 2 + + # # Normalize coordinates to range [-1, 1] + # x_center_normalized = x_center / (orig_width / 2) - 1 + # y_center_normalized = y_center / (orig_width / 2) - 1 + + # # Convert normalized coordinates to spherical (theta, phi) + # theta = x_center_normalized * 180 # Theta ranges from -180 to 180 degrees + # phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees + + # img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image + + fake_frame = None + aligned_img, M = align_crop(frame, target_face.kps, 128) + fake_frame = aligned_img + swap_frame = aligned_img + target_face.matrix = M + for p in self.processors: + if p.type == 'swap': + if inputface is not None: + for _ in range(0,self.options.num_swap_steps): + swap_frame = p.Run(inputface, target_face, swap_frame) + fake_frame = swap_frame + scale_factor = 0.0 + elif p.type == 'mask': + fake_frame = self.process_mask(p, aligned_img, fake_frame) + else: + enhanced_frame, scale_factor = p.Run(self.input_face_datas[face_index], target_face, fake_frame) + + upscale = 512 + orig_width = fake_frame.shape[1] + + fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC) + mask_offsets = (0,0,0,0,1,20) if inputface is None else inputface.mask_offsets + + + if enhanced_frame is None: + scale_factor = int(upscale / orig_width) + result = self.paste_upscale(fake_frame, fake_frame, target_face.matrix, frame, scale_factor, mask_offsets) + else: + result = self.paste_upscale(fake_frame, enhanced_frame, target_face.matrix, frame, scale_factor, mask_offsets) + + if rotation_action is not None: + fake_frame = self.auto_unrotate_frame(result, rotation_action) + return self.paste_simple(fake_frame, saved_frame, startX, startY) + + return result + + + + + def cutout(self, frame:Frame, start_x, start_y, end_x, end_y): + if start_x < 0: + start_x = 0 + if start_y < 0: + start_y = 0 + if end_x > frame.shape[1]: + end_x = frame.shape[1] + if end_y > frame.shape[0]: + end_y = frame.shape[0] + return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y + + def paste_simple(self, src:Frame, dest:Frame, start_x, start_y): + end_x = start_x + src.shape[1] + end_y = start_y + src.shape[0] + + start_x, end_x, start_y, end_y = clamp_cut_values(start_x, end_x, start_y, end_y, dest) + dest[start_y:end_y, start_x:end_x] = src + return dest + + def simple_blend_with_mask(self, image1, image2, mask): + # Blend the images + blended_image = image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask + return blended_image.astype(np.uint8) + + + def paste_upscale(self, fake_face, upsk_face, M, target_img, scale_factor, mask_offsets): + M_scale = M * scale_factor + IM = cv2.invertAffineTransform(M_scale) + + face_matte = np.full((target_img.shape[0],target_img.shape[1]), 255, dtype=np.uint8) + # Generate white square sized as a upsk_face + img_matte = np.zeros((upsk_face.shape[0],upsk_face.shape[1]), dtype=np.uint8) + + w = img_matte.shape[1] + h = img_matte.shape[0] + + top = int(mask_offsets[0] * h) + bottom = int(h - (mask_offsets[1] * h)) + left = int(mask_offsets[2] * w) + right = int(w - (mask_offsets[3] * w)) + img_matte[top:bottom,left:right] = 255 + + # Transform white square back to target_img + img_matte = cv2.warpAffine(img_matte, IM, (target_img.shape[1], target_img.shape[0]), flags=cv2.INTER_NEAREST, borderValue=0.0) + ##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges) + img_matte[:1,:] = img_matte[-1:,:] = img_matte[:,:1] = img_matte[:,-1:] = 0 + + img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5]) + #Normalize images to float values and reshape + img_matte = img_matte.astype(np.float32)/255 + face_matte = face_matte.astype(np.float32)/255 + img_matte = np.minimum(face_matte, img_matte) + if self.options.show_face_area_overlay: + # Additional steps for green overlay + green_overlay = np.zeros_like(target_img) + green_color = [0, 255, 0] # RGB for green + for i in range(3): # Apply green color where img_matte is not zero + green_overlay[:, :, i] = np.where(img_matte > 0, green_color[i], 0) ##Transform upcaled face back to target_img + img_matte = np.reshape(img_matte, [img_matte.shape[0],img_matte.shape[1],1]) + paste_face = cv2.warpAffine(upsk_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE) + if upsk_face is not fake_face: + fake_face = cv2.warpAffine(fake_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE) + paste_face = cv2.addWeighted(paste_face, self.options.blend_ratio, fake_face, 1.0 - self.options.blend_ratio, 0) + + # Re-assemble image + paste_face = img_matte * paste_face + paste_face = paste_face + (1-img_matte) * target_img.astype(np.float32) + if self.options.show_face_area_overlay: + # Overlay the green overlay on the final image + paste_face = cv2.addWeighted(paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0) + return paste_face.astype(np.uint8) + + + def blur_area(self, img_matte, num_erosion_iterations, blur_amount): + # Detect the affine transformed white area + mask_h_inds, mask_w_inds = np.where(img_matte==255) + # Calculate the size (and diagonal size) of transformed white area width and height boundaries + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h*mask_w)) + # Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10)) + # k = max(mask_size//12, 8) + k = max(mask_size//(blur_amount // 2) , blur_amount // 2) + kernel = np.ones((k,k),np.uint8) + img_matte = cv2.erode(img_matte,kernel,iterations = num_erosion_iterations) + #Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5)) + # k = max(mask_size//24, 4) + k = max(mask_size//blur_amount, blur_amount//5) + kernel_size = (k, k) + blur_size = tuple(2*i+1 for i in kernel_size) + return cv2.GaussianBlur(img_matte, blur_size, 0) + + + def process_mask(self, processor, frame:Frame, target:Frame): + img_mask = processor.Run(frame, self.options.masking_text) + img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0])) + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + + if self.options.show_face_masking: + result = (1 - img_mask) * frame.astype(np.float32) + return np.uint8(result) + + + target = target.astype(np.float32) + result = (1-img_mask) * target + result += img_mask * frame.astype(np.float32) + return np.uint8(result) + + + + + def unload_models(): + pass + + + def release_resources(self): + for p in self.processors: + p.Release() + self.processors.clear() + diff --git a/roop/ProcessOptions.py b/roop/ProcessOptions.py new file mode 100644 index 0000000000000000000000000000000000000000..296e8b243796408555a885d11548278ef6ca363c --- /dev/null +++ b/roop/ProcessOptions.py @@ -0,0 +1,13 @@ +class ProcessOptions: + + def __init__(self, processordefines:dict, face_distance, blend_ratio, swap_mode, selected_index, masking_text, imagemask, num_steps, show_face_area, show_mask=False): + self.processors = processordefines + self.face_distance_threshold = face_distance + self.blend_ratio = blend_ratio + self.swap_mode = swap_mode + self.selected_index = selected_index + self.masking_text = masking_text + self.imagemask = imagemask + self.num_swap_steps = num_steps + self.show_face_area_overlay = show_face_area + self.show_face_masking = show_mask \ No newline at end of file diff --git a/roop/__pycache__/FaceSet.cpython-310.pyc b/roop/__pycache__/FaceSet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20b0c389973dcb60a1260e27919e55236e01b130 Binary files /dev/null and b/roop/__pycache__/FaceSet.cpython-310.pyc differ diff --git a/roop/__pycache__/ProcessEntry.cpython-310.pyc b/roop/__pycache__/ProcessEntry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47dc20d22dcce386a06497cd49e4600da54cd364 Binary files /dev/null and b/roop/__pycache__/ProcessEntry.cpython-310.pyc differ diff --git a/roop/__pycache__/ProcessMgr.cpython-310.pyc b/roop/__pycache__/ProcessMgr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b0187480d30957512149ee617165a11a522b560 Binary files /dev/null and b/roop/__pycache__/ProcessMgr.cpython-310.pyc differ diff --git a/roop/__pycache__/ProcessOptions.cpython-310.pyc b/roop/__pycache__/ProcessOptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a0a5ff2a01ff19c933d486ade5803e17df1dc1b Binary files /dev/null and b/roop/__pycache__/ProcessOptions.cpython-310.pyc differ diff --git a/roop/__pycache__/__init__.cpython-310.pyc b/roop/__pycache__/__init__.cpython-310.pyc index 0f008d4528c589db44822d308729d41948f137ea..c819adf07f6fc0e20d6e81b94489ecbc74604609 100644 Binary files a/roop/__pycache__/__init__.cpython-310.pyc and b/roop/__pycache__/__init__.cpython-310.pyc differ diff --git a/roop/__pycache__/capturer.cpython-310.pyc b/roop/__pycache__/capturer.cpython-310.pyc index 926c8f693b18f6ad0c5870377b1238f073aa3d6d..b8ce285cbbd58bd2468bb615c0f02c01b22f3e66 100644 Binary files a/roop/__pycache__/capturer.cpython-310.pyc and b/roop/__pycache__/capturer.cpython-310.pyc differ diff --git a/roop/__pycache__/core.cpython-310.pyc b/roop/__pycache__/core.cpython-310.pyc index 109e8a32aa7ebb2b970499b7339a0d2b2d18a3ce..236976f3064aa105cabc315fc566febdb0334201 100644 Binary files a/roop/__pycache__/core.cpython-310.pyc and b/roop/__pycache__/core.cpython-310.pyc differ diff --git a/roop/__pycache__/face_util.cpython-310.pyc b/roop/__pycache__/face_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcb2830f446baff1378ef151f49212d234c44903 Binary files /dev/null and b/roop/__pycache__/face_util.cpython-310.pyc differ diff --git a/roop/__pycache__/ffmpeg_writer.cpython-310.pyc b/roop/__pycache__/ffmpeg_writer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d566c96f7a400b9af0750b691745c2daad84570 Binary files /dev/null and b/roop/__pycache__/ffmpeg_writer.cpython-310.pyc differ diff --git a/roop/__pycache__/globals.cpython-310.pyc b/roop/__pycache__/globals.cpython-310.pyc index c9e968afd248ca325d79901862b1a49052bc8493..960c7ba209f2c183cb4983b0f1acad0257766daf 100644 Binary files a/roop/__pycache__/globals.cpython-310.pyc and b/roop/__pycache__/globals.cpython-310.pyc differ diff --git a/roop/__pycache__/metadata.cpython-310.pyc b/roop/__pycache__/metadata.cpython-310.pyc index 207bc9d00a68c048674c7fffe64f17297f076b8d..a22e92f0d3f6c429d658d6da01c52dcfb3fbdd39 100644 Binary files a/roop/__pycache__/metadata.cpython-310.pyc and b/roop/__pycache__/metadata.cpython-310.pyc differ diff --git a/roop/__pycache__/template_parser.cpython-310.pyc b/roop/__pycache__/template_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d3dc40a001cdb50c0ab88d1dd6f188f5315447b Binary files /dev/null and b/roop/__pycache__/template_parser.cpython-310.pyc differ diff --git a/roop/__pycache__/typing.cpython-310.pyc b/roop/__pycache__/typing.cpython-310.pyc index b0fcc940e3648ecd3c0515f254171889e6e56c7d..86161bb6f7b76f6f065ef5590d1ac3d1ca0de4be 100644 Binary files a/roop/__pycache__/typing.cpython-310.pyc and b/roop/__pycache__/typing.cpython-310.pyc differ diff --git a/roop/__pycache__/util_ffmpeg.cpython-310.pyc b/roop/__pycache__/util_ffmpeg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1632f2e630b16e99163ccddc73dc885079a115e Binary files /dev/null and b/roop/__pycache__/util_ffmpeg.cpython-310.pyc differ diff --git a/roop/__pycache__/utilities.cpython-310.pyc b/roop/__pycache__/utilities.cpython-310.pyc index da78b773dd7024aa067c530879e8379fda9d2436..3af28a5a995315b143f0e64961af04be17752651 100644 Binary files a/roop/__pycache__/utilities.cpython-310.pyc and b/roop/__pycache__/utilities.cpython-310.pyc differ diff --git a/roop/__pycache__/vr_util.cpython-310.pyc b/roop/__pycache__/vr_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc7ad38a56fd93d4166a0888cc4540be471a94b Binary files /dev/null and b/roop/__pycache__/vr_util.cpython-310.pyc differ diff --git a/roop/capturer.py b/roop/capturer.py index fd49d468dd4cd45832ab9612205968207a6f45cf..6da3ac082fb0b2498e253c05ecae429f81fd1c70 100644 --- a/roop/capturer.py +++ b/roop/capturer.py @@ -1,8 +1,18 @@ -from typing import Any +from typing import Optional import cv2 +import numpy as np +from roop.typing import Frame -def get_video_frame(video_path: str, frame_number: int = 0) -> Any: +def get_image_frame(filename: str): + try: + return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR) + except: + print(f"Exception reading {filename}") + return None + + +def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]: capture = cv2.VideoCapture(video_path) frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) diff --git a/roop/core.py b/roop/core.py index 29652b16a11212d8ab6abe16ddee7e73244eba6b..58091faba06fbc5bbdb8fa3a66de909157153927 100644 --- a/roop/core.py +++ b/roop/core.py @@ -2,26 +2,29 @@ import os import sys -# single thread doubles cuda performance - needs to be set before torch import -if any(arg.startswith('--execution-provider') for arg in sys.argv): - os.environ['OMP_NUM_THREADS'] = '1' -# reduce tensorflow log level -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +import shutil +import argparse import warnings from typing import List import platform import signal -import shutil -import argparse import torch import onnxruntime -import tensorflow - +import pathlib +from time import time import roop.globals import roop.metadata -from roop.predicter import predict_image, predict_video -from roop.processors.frame.core import get_frame_processors_modules -from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path +import roop.utilities as util +import roop.util_ffmpeg as ffmpeg +from settings import Settings +from roop.face_util import extract_face_images +from roop.ProcessEntry import ProcessEntry +from roop.ProcessMgr import ProcessMgr +from roop.ProcessOptions import ProcessOptions +from roop.capturer import get_video_frame_total +from roop.FaceSet import FaceSet + +process_mgr = None if 'ROCMExecutionProvider' in roop.globals.execution_providers: del torch @@ -30,40 +33,21 @@ warnings.filterwarnings('ignore', category=FutureWarning, module='insightface') warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') -def parse_args() -> None: - signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) - program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)) - program.add_argument('-s', '--source', help='select an source image', dest='source_path') - program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') - program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') - program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+') - program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False) - program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True) - program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False) - program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False) - program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9']) - program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]') - program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory()) - program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+') - program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads()) - program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}') - - args = program.parse_args() - - roop.globals.source_path = args.source_path - roop.globals.target_path = args.target_path - roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path) - roop.globals.frame_processors = args.frame_processor - roop.globals.headless = args.source_path or args.target_path or args.output_path - roop.globals.keep_fps = args.keep_fps - roop.globals.keep_audio = args.keep_audio - roop.globals.keep_frames = args.keep_frames - roop.globals.many_faces = args.many_faces - roop.globals.video_encoder = args.video_encoder - roop.globals.video_quality = args.video_quality - roop.globals.max_memory = args.max_memory - roop.globals.execution_providers = decode_execution_providers(args.execution_provider) - roop.globals.execution_threads = args.execution_threads +def parse_args(): + parser = argparse.ArgumentParser(description="Run Roop from the command line") + parser.add_argument('--source_path', type=str, required=True, help="Path to the source file") + parser.add_argument('--target_path', type=str, required=True, help="Path to the target file") + parser.add_argument('--output_path', type=str, required=True, help="Path to save the output file") + parser.add_argument('--execution_provider', type=str, default='CPUExecutionProvider', help="Execution provider for ONNX runtime") + parser.add_argument('--max_memory', type=int, default=None, help="Max memory to use (in GB)") + parser.add_argument('--distance_threshold', type=float, default=0.6, help="Distance threshold for face matching") + parser.add_argument('--blend_ratio', type=float, default=0.5, help="Blend ratio for face swapping") + parser.add_argument('--face_swap_mode', type=str, default='replace', help="Face swap mode") + parser.add_argument('--output_image_format', type=str, default='png', help="Output image format") + parser.add_argument('--output_video_format', type=str, default='mp4', help="Output video format") + parser.add_argument('--execution_threads', type=int, default=8, help="Number of threads to use for execution") + parser.add_argument('--skip_audio', action='store_true', help="Skip audio when processing video") + return parser.parse_args() def encode_execution_providers(execution_providers: List[str]) -> List[str]: @@ -77,8 +61,8 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]: def suggest_max_memory() -> int: if platform.system().lower() == 'darwin': - return 10 - return 14 + return 4 + return 16 def suggest_execution_providers() -> List[str]: @@ -94,12 +78,6 @@ def suggest_execution_threads() -> int: def limit_resources() -> None: - # prevent tensorflow memory leak - gpus = tensorflow.config.experimental.list_physical_devices('GPU') - for gpu in gpus: - tensorflow.config.experimental.set_virtual_device_configuration(gpu, [ - tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024) - ]) # limit memory usage if roop.globals.max_memory: memory = roop.globals.max_memory * 1024 ** 3 @@ -107,7 +85,7 @@ def limit_resources() -> None: memory = roop.globals.max_memory * 1024 ** 6 if platform.system().lower() == 'windows': import ctypes - kernel32 = ctypes.windll.kernel32 + kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined] kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)) else: import resource @@ -115,95 +93,314 @@ def limit_resources() -> None: def release_resources() -> None: - if 'CUDAExecutionProvider' in roop.globals.execution_providers: - torch.cuda.empty_cache() + import gc + global process_mgr + + if process_mgr is not None: + process_mgr.release_resources() + process_mgr = None + + gc.collect() def pre_check() -> bool: if sys.version_info < (3, 9): update_status('Python version is not supported - please upgrade to 3.9 or higher.') return False + + download_directory_path = util.resolve_relative_path('../models') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx']) + util.conditional_download(download_directory_path, ['https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/xseg.onnx']) + download_directory_path = util.resolve_relative_path('../models/CLIP') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth']) + download_directory_path = util.resolve_relative_path('../models/CodeFormer') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx']) + download_directory_path = util.resolve_relative_path('../models/Frame') + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_artistic.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_stable.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/isnet-general-use.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x4.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x2.onnx']) + util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/lsdir_x4.onnx']) + if not shutil.which('ffmpeg'): update_status('ffmpeg is not installed.') - return False return True -def update_status(message: str, scope: str = 'ROOP.CORE') -> None: - print(f'[{scope}] {message}') - - -def start() -> None: - for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): - if not frame_processor.pre_start(): - return - # process image to image - if has_image_extension(roop.globals.target_path): - if predict_image(roop.globals.target_path): - destroy() - shutil.copy2(roop.globals.target_path, roop.globals.output_path) - for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): - update_status('Progressing...', frame_processor.NAME) - frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path) - frame_processor.post_process() - release_resources() - if is_image(roop.globals.target_path): - update_status('Processing to image succeed!') - else: - update_status('Processing to image failed!') - return - # process image to videos - if predict_video(roop.globals.target_path): - destroy() - update_status('Creating temp resources...') - create_temp(roop.globals.target_path) - update_status('Extracting frames...') - extract_frames(roop.globals.target_path) - temp_frame_paths = get_temp_frame_paths(roop.globals.target_path) - for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): - update_status('Progressing...', frame_processor.NAME) - frame_processor.process_video(roop.globals.source_path, temp_frame_paths) - frame_processor.post_process() - release_resources() - # handles fps - if roop.globals.keep_fps: - update_status('Detecting fps...') - fps = detect_fps(roop.globals.target_path) - update_status(f'Creating video with {fps} fps...') - create_video(roop.globals.target_path, fps) - else: - update_status('Creating video with 30.0 fps...') - create_video(roop.globals.target_path) - # handle audio - if roop.globals.keep_audio: - if roop.globals.keep_fps: - update_status('Restoring audio...') - else: - update_status('Restoring audio might cause issues as fps are not kept...') - restore_audio(roop.globals.target_path, roop.globals.output_path) - else: - move_temp(roop.globals.target_path, roop.globals.output_path) - # clean and validate - clean_temp(roop.globals.target_path) - if is_video(roop.globals.target_path): - update_status('Processing to video succeed!') - else: - update_status('Processing to video failed!') +def update_status(message: str) -> None: + print(message) + + +def get_processing_plugins(masking_engine): + processors = {"faceswap": {}} + if masking_engine is not None: + processors.update({masking_engine: {}}) + + if roop.globals.selected_enhancer == 'GFPGAN': + processors.update({"gfpgan": {}}) + elif roop.globals.selected_enhancer == 'Codeformer': + processors.update({"codeformer": {}}) + elif roop.globals.selected_enhancer == 'DMDNet': + processors.update({"dmdnet": {}}) + elif roop.globals.selected_enhancer == 'GPEN': + processors.update({"gpen": {}}) + elif roop.globals.selected_enhancer == 'Restoreformer++': + processors.update({"restoreformer++": {}}) + return processors + + +def live_swap(frame, options): + global process_mgr + + if frame is None: + return frame + + if process_mgr is None: + process_mgr = ProcessMgr(None) + + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + newframe = process_mgr.process_frame(frame) + if newframe is None: + return frame + return newframe + + +def batch_process_regular(files: List[ProcessEntry], masking_engine: str, new_clip_text: str, use_new_method, imagemask, num_swap_steps, progress, selected_index=0) -> None: + global process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + mask = imagemask["layers"][0] if imagemask is not None else None + if len(roop.globals.INPUT_FACESETS) <= selected_index: + selected_index = 0 + options = ProcessOptions(get_processing_plugins(masking_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, roop.globals.face_swap_mode, selected_index, new_clip_text, mask, num_swap_steps, False) + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + batch_process(files, use_new_method) + return + + +def batch_process_with_options(files: List[ProcessEntry], options, progress): + global process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options) + roop.globals.keep_frames = False + roop.globals.wait_after_extraction = False + roop.globals.skip_audio = False + batch_process(files, True) + + +def batch_process(files: List[ProcessEntry], use_new_method) -> None: + global process_mgr + + roop.globals.processing = True + + max_threads = suggest_execution_threads() + if max_threads == 1: + roop.globals.execution_threads = 1 + + imagefiles: List[ProcessEntry] = [] + videofiles: List[ProcessEntry] = [] + + update_status('Sorting videos/images') + + for index, f in enumerate(files): + fullname = f.filename + if util.has_image_extension(fullname): + destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'.{roop.globals.CFG.output_image_format}') + destination = util.replace_template(destination, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + f.finalname = destination + imagefiles.append(f) + + elif util.is_video(fullname) or util.has_extension(fullname, ['gif']): + destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'__temp.{roop.globals.CFG.output_video_format}') + f.finalname = destination + videofiles.append(f) + + if len(imagefiles) > 0: + update_status('Processing image(s)') + origimages = [] + fakeimages = [] + for f in imagefiles: + origimages.append(f.filename) + fakeimages.append(f.finalname) + + process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads) + origimages.clear() + fakeimages.clear() + + if len(videofiles) > 0: + for index, v in enumerate(videofiles): + if not roop.globals.processing: + end_processing('Processing stopped!') + return + fps = v.fps if v.fps > 0 else util.detect_fps(v.filename) + if v.endframe == 0: + v.endframe = get_video_frame_total(v.filename) + + update_status(f'Creating {os.path.basename(v.finalname)} with {fps} FPS...') + start_processing = time() + if roop.globals.keep_frames or not use_new_method: + util.create_temp(v.filename) + update_status('Extracting frames...') + ffmpeg.extract_frames(v.filename, v.startframe, v.endframe, fps) + if not roop.globals.processing: + end_processing('Processing stopped!') + return + + temp_frame_paths = util.get_temp_frame_paths(v.filename) + process_mgr.run_batch(temp_frame_paths, temp_frame_paths, roop.globals.execution_threads) + if not roop.globals.processing: + end_processing('Processing stopped!') + return + if roop.globals.wait_after_extraction: + extract_path = os.path.dirname(temp_frame_paths[0]) + util.open_folder(extract_path) + input("Press any key to continue...") + print("Resorting frames to create video") + util.sort_rename_frames(extract_path) + + ffmpeg.create_video(v.filename, v.finalname, fps) + if not roop.globals.keep_frames: + util.delete_temp_frames(temp_frame_paths[0]) + else: + if util.has_extension(v.filename, ['gif']): + skip_audio = True + else: + skip_audio = roop.globals.skip_audio + process_mgr.run_batch_inmem(v.filename, v.finalname, v.startframe, v.endframe, fps, roop.globals.execution_threads, skip_audio) + + if not roop.globals.processing: + end_processing('Processing stopped!') + return + + video_file_name = v.finalname + if os.path.isfile(video_file_name): + destination = '' + if util.has_extension(v.filename, ['gif']): + gifname = util.get_destfilename_from_path(v.filename, roop.globals.output_path, '.gif') + destination = util.replace_template(gifname, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + + update_status('Creating final GIF') + ffmpeg.create_gif_from_video(video_file_name, destination) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + skip_audio = roop.globals.skip_audio + destination = util.replace_template(video_file_name, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True) + + if not skip_audio: + ffmpeg.restore_audio(video_file_name, v.filename, v.startframe, v.endframe, destination) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + shutil.move(video_file_name, destination) + update_status(f'\nProcessing {os.path.basename(destination)} took {time() - start_processing} secs') + + else: + update_status(f'Failed processing {os.path.basename(v.finalname)}!') + end_processing('Finished') + + +def end_processing(msg: str): + update_status(msg) + roop.globals.target_folder_path = None + release_resources() def destroy() -> None: if roop.globals.target_path: - clean_temp(roop.globals.target_path) - quit() + util.clean_temp(roop.globals.target_path) + release_resources() + sys.exit() def run() -> None: - parse_args() + args = parse_args() + + roop.globals.source_path = args.source_path + roop.globals.target_path = args.target_path + roop.globals.output_path = args.output_path + roop.globals.execution_providers = decode_execution_providers([args.execution_provider]) + roop.globals.max_memory = args.max_memory + roop.globals.distance_threshold = args.distance_threshold + roop.globals.blend_ratio = args.blend_ratio + roop.globals.face_swap_mode = args.face_swap_mode + roop.globals.CFG = Settings('config.yaml') + roop.globals.execution_threads = args.execution_threads + roop.globals.output_image_format = args.output_image_format + roop.globals.output_video_format = args.output_video_format + roop.globals.skip_audio = args.skip_audio + roop.globals.face_swap_mode == 'selected' + # Ensure these values are set + if not roop.globals.video_encoder: + roop.globals.video_encoder = 'libx264' # or another suitable default value + if not roop.globals.video_quality: + roop.globals.video_quality = 23 # or another suitable default value + + signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) + if not pre_check(): return - for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): - if not frame_processor.pre_check(): - return - limit_resources() - if roop.globals.headless: - start() \ No newline at end of file + + # Extract faces from the source and target files and create FaceSet objects + source_faces = extract_face_images(args.source_path, (False, 0)) + target_faces = extract_face_images(args.target_path, (False, util.has_image_extension(args.target_path))) + print("Number of targets faces is ", target_faces.count) + + if source_faces: + source_face_set = FaceSet() + for face_data in source_faces: + face = face_data[0] + face.mask_offsets = (0, 0, 0, 0, 1, 20) + source_face_set.faces.append(face) + if len(source_face_set.faces) > 1: + source_face_set.AverageEmbeddings() + roop.globals.INPUT_FACESETS.append(source_face_set) + + if target_faces: + target_face_set = FaceSet() + for face_data in target_faces: + face = face_data[0] + face.mask_offsets = (0, 0, 0, 0, 1, 20) + target_face_set.faces.append(face) + if len(target_face_set.faces) > 1: + target_face_set.AverageEmbeddings() + roop.globals.TARGET_FACES.append(target_face_set.faces[0]) # Assuming using the first face for target + + # Detect fps and endframe values for the source and target videos + source_fps = util.detect_fps(args.source_path) + source_endframe = get_video_frame_total(args.source_path) + target_fps = util.detect_fps(args.target_path) + target_endframe = get_video_frame_total(args.target_path) + + # Initialize ProcessEntry objects using detected values + source_entry = ProcessEntry( + filename=args.source_path, + start=0, + end=source_endframe, + fps=source_fps + ) + + target_entry = ProcessEntry( + filename=args.target_path, + start=0, + end=target_endframe, + fps=target_fps + ) + + files = [source_entry, target_entry] + batch_process_regular(files, None, None, False, None, 1, None) diff --git a/roop/face_util.py b/roop/face_util.py new file mode 100644 index 0000000000000000000000000000000000000000..d870632d6d83cf3a007ae065f76a0ded8ea17732 --- /dev/null +++ b/roop/face_util.py @@ -0,0 +1,310 @@ +import threading +from typing import Any +import insightface + +import roop.globals +from roop.typing import Frame, Face + +import cv2 +import numpy as np +from skimage import transform as trans +from roop.capturer import get_video_frame +from roop.utilities import resolve_relative_path, conditional_download + +FACE_ANALYSER = None +THREAD_LOCK_ANALYSER = threading.Lock() +THREAD_LOCK_SWAPPER = threading.Lock() +FACE_SWAPPER = None + + +def get_face_analyser() -> Any: + global FACE_ANALYSER + + with THREAD_LOCK_ANALYSER: + if FACE_ANALYSER is None or roop.globals.g_current_face_analysis != roop.globals.g_desired_face_analysis: + model_path = resolve_relative_path('..') + # removed genderage + allowed_modules = roop.globals.g_desired_face_analysis + roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis + if roop.globals.CFG.force_cpu: + print("Forcing CPU for Face Analysis") + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", + root=model_path, providers=["CPUExecutionProvider"],allowed_modules=allowed_modules + ) + else: + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", root=model_path, providers=roop.globals.execution_providers,allowed_modules=allowed_modules + ) + FACE_ANALYSER.prepare( + ctx_id=0, + det_size=(640, 640) if roop.globals.default_det_size else (320, 320), + ) + return FACE_ANALYSER + + +def get_first_face(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return min(faces, key=lambda x: x.bbox[0]) + # return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0] + except: + return None + + +def get_all_faces(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return sorted(faces, key=lambda x: x.bbox[0]) + except: + return None + + +def extract_face_images(source_filename, video_info, extra_padding=-1.0): + face_data = [] + source_image = None + + if video_info[0]: + frame = get_video_frame(source_filename, video_info[1]) + if frame is not None: + source_image = frame + else: + return face_data + else: + source_image = cv2.imdecode(np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR) + + if source_image is None: + print("No source image!") + + faces = get_all_faces(source_image) + if faces is None: + print("NO faces here!") + return face_data + + i = 0 + for face in faces: + (startX, startY, endX, endY) = face["bbox"].astype("int") + startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image) + if extra_padding > 0.0: + if source_image.shape[:2] == (512, 512): + i += 1 + face_data.append([face, source_image]) + continue + + found = False + for i in range(1, 3): + (startX, startY, endX, endY) = face["bbox"].astype("int") + startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image) + cutout_padding = extra_padding + # top needs extra room for detection + padding = int((endY - startY) * cutout_padding) + oldY = startY + startY -= padding + + factor = 0.25 if i == 1 else 0.5 + cutout_padding = factor + padding = int((endY - oldY) * cutout_padding) + endY += padding + padding = int((endX - startX) * cutout_padding) + startX -= padding + endX += padding + startX, endX, startY, endY = clamp_cut_values( + startX, endX, startY, endY, source_image + ) + face_temp = source_image[startY:endY, startX:endX] + face_temp = resize_image_keep_content(face_temp) + testfaces = get_all_faces(face_temp) + if testfaces is not None and len(testfaces) > 0: + i += 1 + face_data.append([testfaces[0], face_temp]) + found = True + break + + if not found: + print("No face found after resizing, this shouldn't happen!") + continue + + face_temp = source_image[startY:endY, startX:endX] + if face_temp.size < 1: + continue + + i += 1 + face_data.append([face, face_temp]) + return face_data + + +def clamp_cut_values(startX, endX, startY, endY, image): + if startX < 0: + startX = 0 + if endX > image.shape[1]: + endX = image.shape[1] + if startY < 0: + startY = 0 + if endY > image.shape[0]: + endY = image.shape[0] + return startX, endX, startY, endY + + + +def face_offset_top(face: Face, offset): + face["bbox"][1] += offset + face["bbox"][3] += offset + lm106 = face.landmark_2d_106 + add = np.full_like(lm106, [0, offset]) + face["landmark_2d_106"] = lm106 + add + return face + + +def resize_image_keep_content(image, new_width=512, new_height=512): + dim = None + (h, w) = image.shape[:2] + if h > w: + r = new_height / float(h) + dim = (int(w * r), new_height) + else: + # Calculate the ratio of the width and construct the dimensions + r = new_width / float(w) + dim = (new_width, int(h * r)) + image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) + (h, w) = image.shape[:2] + if h == new_height and w == new_width: + return image + resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype) + offs = (new_width - w) if h == new_height else (new_height - h) + startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1 + offs = int(offs // 2) + + if h == new_height: + resize_img[0:new_height, startoffs : new_width - offs] = image + else: + resize_img[startoffs : new_height - offs, 0:new_width] = image + return resize_img + + +def rotate_image_90(image, rotate=True): + if rotate: + return np.rot90(image) + else: + return np.rot90(image, 1, (1, 0)) + + +def rotate_anticlockwise(frame): + return rotate_image_90(frame) + + +def rotate_clockwise(frame): + return rotate_image_90(frame, False) + + +def rotate_image_180(image): + return np.flip(image, 0) + + +# alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py + +arcface_dst = np.array( + [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], + ], + dtype=np.float32, +) + + +def estimate_norm(lmk, image_size=112, mode="arcface"): + assert lmk.shape == (5, 2) + assert image_size % 112 == 0 or image_size % 128 == 0 + if image_size % 112 == 0: + ratio = float(image_size) / 112.0 + diff_x = 0 + else: + ratio = float(image_size) / 128.0 + diff_x = 8.0 * ratio + dst = arcface_dst * ratio + dst[:, 0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + + + +# aligned, M = norm_crop2(f[1], face.kps, 512) +def align_crop(img, landmark, image_size=112, mode="arcface"): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped, M + + +def square_crop(im, S): + if im.shape[0] > im.shape[1]: + height = S + width = int(float(im.shape[1]) / im.shape[0] * S) + scale = float(S) / im.shape[0] + else: + width = S + height = int(float(im.shape[0]) / im.shape[1] * S) + scale = float(S) / im.shape[1] + resized_im = cv2.resize(im, (width, height)) + det_im = np.zeros((S, S, 3), dtype=np.uint8) + det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im + return det_im, scale + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + # print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + +def create_blank_image(width, height): + img = np.zeros((height, width, 4), dtype=np.uint8) + img[:] = [0,0,0,0] + return img + diff --git a/roop/ffmpeg_writer.py b/roop/ffmpeg_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..9642efad2de4e2b3463a62d1ee04b5f02402702c --- /dev/null +++ b/roop/ffmpeg_writer.py @@ -0,0 +1,218 @@ +""" +FFMPEG_Writer - write set of frames to video file + +original from +https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py + +removed unnecessary dependencies + +The MIT License (MIT) + +Copyright (c) 2015 Zulko +Copyright (c) 2023 Janvarev Vladislav +""" + +import os +import subprocess as sp + +PIPE = -1 +STDOUT = -2 +DEVNULL = -3 + +FFMPEG_BINARY = "ffmpeg" + +class FFMPEG_VideoWriter: + """ A class for FFMPEG-based video writing. + + A class to write videos using ffmpeg. ffmpeg will write in a large + choice of formats. + + Parameters + ----------- + + filename + Any filename like 'video.mp4' etc. but if you want to avoid + complications it is recommended to use the generic extension + '.avi' for all your videos. + + size + Size (width,height) of the output video in pixels. + + fps + Frames per second in the output video file. + + codec + FFMPEG codec. It seems that in terms of quality the hierarchy is + 'rawvideo' = 'png' > 'mpeg4' > 'libx264' + 'png' manages the same lossless quality as 'rawvideo' but yields + smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list + of accepted codecs. + + Note for default 'libx264': by default the pixel format yuv420p + is used. If the video dimensions are not both even (e.g. 720x405) + another pixel format is used, and this can cause problem in some + video readers. + + audiofile + Optional: The name of an audio file that will be incorporated + to the video. + + preset + Sets the time that FFMPEG will take to compress the video. The slower, + the better the compression rate. Possibilities are: ultrafast,superfast, + veryfast, faster, fast, medium (default), slow, slower, veryslow, + placebo. + + bitrate + Only relevant for codecs which accept a bitrate. "5000k" offers + nice results in general. + + """ + + def __init__(self, filename, size, fps, codec="libx265", crf=14, audiofile=None, + preset="medium", bitrate=None, + logfile=None, threads=None, ffmpeg_params=None): + + if logfile is None: + logfile = sp.PIPE + + self.filename = filename + self.codec = codec + self.ext = self.filename.split(".")[-1] + w = size[0] - 1 if size[0] % 2 != 0 else size[0] + h = size[1] - 1 if size[1] % 2 != 0 else size[1] + + + # order is important + cmd = [ + FFMPEG_BINARY, + '-hide_banner', + '-hwaccel', 'auto', + '-y', + '-loglevel', 'error' if logfile == sp.PIPE else 'info', + '-f', 'rawvideo', + '-vcodec', 'rawvideo', + '-s', '%dx%d' % (size[0], size[1]), + #'-pix_fmt', 'rgba' if withmask else 'rgb24', + '-pix_fmt', 'bgr24', + '-r', str(fps), + '-an', '-i', '-' + ] + + if audiofile is not None: + cmd.extend([ + '-i', audiofile, + '-acodec', 'copy' + ]) + + cmd.extend([ + '-vcodec', codec, + '-crf', str(crf) + #'-preset', preset, + ]) + if ffmpeg_params is not None: + cmd.extend(ffmpeg_params) + if bitrate is not None: + cmd.extend([ + '-b', bitrate + ]) + + # scale to a resolution divisible by 2 if not even + cmd.extend(['-vf', f'scale={w}:{h}' if w != size[0] or h != size[1] else 'colorspace=bt709:iall=bt601-6-625:fast=1']) + + if threads is not None: + cmd.extend(["-threads", str(threads)]) + + cmd.extend([ + '-pix_fmt', 'yuv420p', + + ]) + cmd.extend([ + filename + ]) + + test = str(cmd) + print(test) + + popen_params = {"stdout": DEVNULL, + "stderr": logfile, + "stdin": sp.PIPE} + + # This was added so that no extra unwanted window opens on windows + # when the child process is created + if os.name == "nt": + popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW + + self.proc = sp.Popen(cmd, **popen_params) + + + def write_frame(self, img_array): + """ Writes one frame in the file.""" + try: + #if PY3: + self.proc.stdin.write(img_array.tobytes()) + # else: + # self.proc.stdin.write(img_array.tostring()) + except IOError as err: + _, ffmpeg_error = self.proc.communicate() + error = (str(err) + ("\n\nroop unleashed error: FFMPEG encountered " + "the following error while writing file %s:" + "\n\n %s" % (self.filename, str(ffmpeg_error)))) + + if b"Unknown encoder" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed because FFMPEG didn't find the specified " + "codec for video encoding (%s). Please install " + "this codec or change the codec when calling " + "write_videofile. For instance:\n" + " >>> clip.write_videofile('myvid.webm', codec='libvpx')")%(self.codec) + + elif b"incorrect codec parameters ?" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed, possibly because the codec specified for " + "the video (%s) is not compatible with the given " + "extension (%s). Please specify a valid 'codec' " + "argument in write_videofile. This would be 'libx264' " + "or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. " + "Another possible reason is that the audio codec was not " + "compatible with the video codec. For instance the video " + "extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a" + "video codec." + )%(self.codec, self.ext) + + elif b"encoder setup failed" in ffmpeg_error: + + error = error+("\n\nThe video export " + "failed, possibly because the bitrate you specified " + "was too high or too low for the video codec.") + + elif b"Invalid encoder type" in ffmpeg_error: + + error = error + ("\n\nThe video export failed because the codec " + "or file extension you provided is not a video") + + + raise IOError(error) + + def close(self): + if self.proc: + self.proc.stdin.close() + if self.proc.stderr is not None: + self.proc.stderr.close() + self.proc.wait() + + self.proc = None + + # Support the Context Manager protocol, to ensure that resources are cleaned up. + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + + + diff --git a/roop/globals.py b/roop/globals.py index 77fd391db235b878ce1f91765596bd76adb06697..b1228e3d0652def3c164b332120f8ad2d20292af 100644 --- a/roop/globals.py +++ b/roop/globals.py @@ -1,13 +1,23 @@ +from settings import Settings from typing import List source_path = None target_path = None output_path = None +target_folder_path = None + frame_processors: List[str] = [] keep_fps = None -keep_audio = None keep_frames = None +autorotate_faces = None +vr_mode = None +skip_audio = None +wait_after_extraction = None many_faces = None +use_batch = None +source_face_index = 0 +target_face_index = 0 +face_position = None video_encoder = None video_quality = None max_memory = None @@ -15,3 +25,29 @@ execution_providers: List[str] = [] execution_threads = None headless = None log_level = 'error' +selected_enhancer = None +face_swap_mode = None +blend_ratio = 0.5 +distance_threshold = 0.65 +default_det_size = True + +no_face_action = 0 + +processing = False + +g_current_face_analysis = None +g_desired_face_analysis = None + +FACE_ENHANCER = None + +INPUT_FACESETS = [] +TARGET_FACES = [] + + +IMAGE_CHAIN_PROCESSOR = None +VIDEO_CHAIN_PROCESSOR = None +BATCH_IMAGE_CHAIN_PROCESSOR = None + +CFG: Settings = None + + diff --git a/roop/metadata.py b/roop/metadata.py index 35b0f0245a38eb9ec024f2ed2c829044f6051c29..469e3990c42b6a278b1d7941bdc4dac53f28c72e 100644 --- a/roop/metadata.py +++ b/roop/metadata.py @@ -1,2 +1,2 @@ -name = 'roop' -version = '1.1.0' +name = 'roop unleashed' +version = '4.0.0' diff --git a/roop/processors/Enhance_CodeFormer.py b/roop/processors/Enhance_CodeFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d00a3d431f6b16a659d5722314b3531a6af425d --- /dev/null +++ b/roop/processors/Enhance_CodeFormer.py @@ -0,0 +1,75 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +# THREAD_LOCK = threading.Lock() + + +class Enhance_CodeFormer(): + model_codeformer = None + + plugin_options:dict = None + + processorname = 'codeformer' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_codeformer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx') + self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_codeformer.get_inputs() + model_outputs = self.model_codeformer.get_outputs() + self.io_binding = self.model_codeformer.io_binding() + self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5])) + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + input_size = temp_frame.shape[1] + # preprocess + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32)) + self.model_codeformer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + + # post-process + result = result.transpose((1, 2, 0)) + + un_min = -1.0 + un_max = 1.0 + result = np.clip(result, un_min, un_max) + result = (result - un_min) / (un_max - un_min) + + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + result = (result * 255.0).round() + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + del self.model_codeformer + self.model_codeformer = None + del self.io_binding + self.io_binding = None + diff --git a/roop/processors/Enhance_DMDNet.py b/roop/processors/Enhance_DMDNet.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6a6bb2d2fdad863dcbf66da8e498555d357a64 --- /dev/null +++ b/roop/processors/Enhance_DMDNet.py @@ -0,0 +1,898 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as SpectralNorm +import threading +from torchvision.ops import roi_align + +from math import sqrt + +from torchvision.transforms.functional import normalize + +from roop.typing import Face, Frame, FaceSet + + +THREAD_LOCK_DMDNET = threading.Lock() + + +class Enhance_DMDNet(): + plugin_options:dict = None + model_dmdnet = None + torchdevice = None + + processorname = 'dmdnet' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_dmdnet is None: + self.model_dmdnet = self.create(self.plugin_options["devicename"]) + + + # temp_frame already cropped+aligned, bbox not + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + input_size = temp_frame.shape[1] + + result = self.enhance_face(source_faceset, temp_frame, target_face) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gfpgan = None + + + # https://stackoverflow.com/a/67174339 + def landmarks106_to_68(self, pt106): + map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17, + 43,48,49,51,50, + 102,103,104,105,101, + 72,73,74,86,78,79,80,85,84, + 35,41,42,39,37,36, + 89,95,96,93,91,90, + 52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54 + ] + + pt68 = [] + for i in range(68): + index = map106to68[i] + pt68.append(pt106[index]) + return pt68 + + + + + def check_bbox(self, imgs, boxes): + boxes = boxes.view(-1, 4, 4) + colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)] + i = 0 + for img, box in zip(imgs, boxes): + img = (img + 1)/2 * 255 + img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy() + for idx, point in enumerate(box): + cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2) + cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2) + i += 1 + + + def trans_points2d(self, pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + + def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face): + # preprocess + start_x, start_y, end_x, end_y = map(int, face['bbox']) + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + + if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / temp_frame.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + temp_frame = cv2.resize(temp_frame, (512,512), interpolation = cv2.INTER_AREA) + + if temp_frame.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + lq = read_img_tensor(temp_frame) + + LQLocs = get_component_location(lq_landmarks) + # self.check_bbox(lq, LQLocs.unsqueeze(0)) + + # specific, change 1000 to 1 to activate + if len(ref_faceset.faces) > 1: + SpecificImgs = [] + SpecificLocs = [] + for i,face in enumerate(ref_faceset.faces): + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + ref_image = ref_faceset.ref_images[i] + if ref_image.shape[0] != 512 or ref_image.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / ref_image.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + ref_image = cv2.resize(ref_image, (512,512), interpolation = cv2.INTER_AREA) + + if ref_image.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + ref_tensor = read_img_tensor(ref_image) + ref_locs = get_component_location(lq_landmarks) + # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0)) + + SpecificImgs.append(ref_tensor) + SpecificLocs.append(ref_locs.unsqueeze(0)) + + SpecificImgs = torch.cat(SpecificImgs, dim=0) + SpecificLocs = torch.cat(SpecificLocs, dim=0) + # check_bbox(SpecificImgs, SpecificLocs) + SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs) + SpMem256Para = {} + SpMem128Para = {} + SpMem64Para = {} + for k, v in SpMem256.items(): + SpMem256Para[k] = v + for k, v in SpMem128.items(): + SpMem128Para[k] = v + for k, v in SpMem64.items(): + SpMem64Para[k] = v + else: + # generic + SpMem256Para, SpMem128Para, SpMem64Para = None, None, None + + with torch.no_grad(): + with THREAD_LOCK_DMDNET: + try: + GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para) + except Exception as e: + print(f'Error {e} there may be something wrong with the detected component locations.') + return temp_frame + + if SpecificResult is not None: + save_specific = SpecificResult * 0.5 + 0.5 + save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_specific.astype("uint8") + if False: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + check_lq = lq * 0.5 + 0.5 + check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0 + cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)),cv2.COLOR_RGB2BGR)) + else: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_generic.astype("uint8") + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB + return temp_frame + + + + def create(self, devicename): + self.torchdevice = torch.device(devicename) + model_dmdnet = DMDNet().to(self.torchdevice) + weights = torch.load('./models/DMDNet.pth') + model_dmdnet.load_state_dict(weights, strict=True) + + model_dmdnet.eval() + num_params = 0 + for param in model_dmdnet.parameters(): + num_params += param.numel() + return model_dmdnet + + # print('{:>8s} : {}'.format('Using device', device)) + # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6)) + + + +def read_img_tensor(Img=None): #rgb -1~1 + Img = Img.transpose((2, 0, 1))/255.0 + Img = torch.from_numpy(Img).float() + normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True) + ImgTensor = Img.unsqueeze(0) + return ImgTensor + + +def get_component_location(Landmarks, re_read=False): + if re_read: + ReadLandmark = [] + with open(Landmarks,'r') as f: + for line in f: + tmp = [float(i) for i in line.split(' ') if i != '\n'] + ReadLandmark.append(tmp) + ReadLandmark = np.array(ReadLandmark) # + Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2 + Map_LE_B = list(np.hstack((range(17,22), range(36,42)))) + Map_RE_B = list(np.hstack((range(22,27), range(42,48)))) + Map_LE = list(range(36,42)) + Map_RE = list(range(42,48)) + Map_NO = list(range(29,36)) + Map_MO = list(range(48,68)) + + Landmarks[Landmarks>504]=504 + Landmarks[Landmarks<8]=8 + + #left eye + Mean_LE = np.mean(Landmarks[Map_LE],0) + L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1]) + L_LE1 = L_LE1 * 1.3 + L_LE2 = L_LE1 / 1.9 + L_LE_xy = L_LE1 + L_LE2 + L_LE_lt = [L_LE_xy/2, L_LE1] + L_LE_rb = [L_LE_xy/2, L_LE2] + Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int) + + #right eye + Mean_RE = np.mean(Landmarks[Map_RE],0) + L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1]) + L_RE1 = L_RE1 * 1.3 + L_RE2 = L_RE1 / 1.9 + L_RE_xy = L_RE1 + L_RE2 + L_RE_lt = [L_RE_xy/2, L_RE1] + L_RE_rb = [L_RE_xy/2, L_RE2] + Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int) + + #nose + Mean_NO = np.mean(Landmarks[Map_NO],0) + L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25 + L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1 + L_NO_xy = L_NO1 * 2 + L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2] + L_NO_rb = [L_NO_xy/2, L_NO2] + Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int) + + #mouth + Mean_MO = np.mean(Landmarks[Map_MO],0) + L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1 + MO_O = Mean_MO - L_MO + 1 + MO_T = Mean_MO + L_MO + MO_T[MO_T>510]=510 + Location_MO = np.hstack((MO_O, MO_T)).astype(int) + return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0) + + + + +def calc_mean_std_4D(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature + size = content_feat.size() + style_mean, style_std = calc_mean_std_4D(style_feat) + content_mean, content_std = calc_mean_std_4D(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True): + return nn.Sequential( + SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), + nn.LeakyReLU(0.2), + SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)), + ) + + +class MSDilateBlock(nn.Module): + def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True): + super(MSDilateBlock, self).__init__() + self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias) + self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias) + self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias) + self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias) + self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)) + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(x) + conv3 = self.conv3(x) + conv4 = self.conv4(x) + cat = torch.cat([conv1, conv2, conv3, conv4], 1) + out = self.convi(cat) + x + return out + + +class AdaptiveInstanceNorm(nn.Module): + def __init__(self, in_channel): + super().__init__() + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + style_mean, style_std = calc_mean_std_4D(style) + out = self.norm(input) + size = input.size() + out = style_std.expand(size) * out + style_mean.expand(size) + return out + +class NoiseInjection(nn.Module): + def __init__(self, channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + def forward(self, image, noise): + if noise is None: + b, c, h, w = image.shape + noise = image.new_empty(b, 1, h, w).normal_() + return image + self.weight * noise + +class StyledUpBlock(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False): + super().__init__() + + self.noise_inject = noise_inject + if upsample: + self.conv1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + ) + else: + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + ) + if self.noise_inject: + self.noise1 = NoiseInjection(out_channel) + + self.lrelu1 = nn.LeakyReLU(0.2) + + self.ScaleModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)) + ) + self.ShiftModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), + ) + + def forward(self, input, style): + out = self.conv1(input) + out = self.lrelu1(out) + Shift1 = self.ShiftModel1(style) + Scale1 = self.ScaleModel1(style) + out = out * Scale1 + Shift1 + if self.noise_inject: + out = self.noise1(out, noise=None) + outup = self.convup(out) + return outup + + +#################################################################### +###############Face Dictionary Generator +#################################################################### +def AttentionBlock(in_channel): + return nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + ) + +class DilateResBlock(nn.Module): + def __init__(self, dim, dilation=[5,3] ): + super(DilateResBlock, self).__init__() + self.Res = nn.Sequential( + SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])), + ) + def forward(self, x): + out = x + self.Res(x) + return out + + +class KeyValue(nn.Module): + def __init__(self, indim, keydim, valdim): + super(KeyValue, self).__init__() + self.Key = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.Value = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x): + return self.Key(x), self.Value(x) + +class MaskAttention(nn.Module): + def __init__(self, indim): + super(MaskAttention, self).__init__() + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.conv2 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.conv3 = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + self.convCat = nn.Sequential( + SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x, y, z): + c1 = self.conv1(x) + c2 = self.conv2(y) + c3 = self.conv3(z) + return self.convCat(torch.cat([c1,c2,c3], dim=1)) + +class Query(nn.Module): + def __init__(self, indim, quedim): + super(Query, self).__init__() + self.Query = nn.Sequential( + SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)), + ) + def forward(self, x): + return self.Query(x) + +def roi_align_self(input, location, target_size): + test = (target_size.item(),target_size.item()) + return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],test,mode='bilinear',align_corners=False) for i in range(input.size(0))],0) + +class FeatureExtractor(nn.Module): + def __init__(self, ngf = 64, key_scale = 4):# + super().__init__() + + self.key_scale = 4 + self.part_sizes = np.array([80,80,50,110]) # + self.feature_sizes = np.array([256,128,64]) # + + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + ) + self.conv2 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)) + ) + self.res1 = DilateResBlock(ngf, [5,3]) + self.res2 = DilateResBlock(ngf, [5,3]) + + + self.conv3 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)), + ) + self.conv4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)) + ) + self.res3 = DilateResBlock(ngf*2, [3,1]) + self.res4 = DilateResBlock(ngf*2, [3,1]) + + self.conv5 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)), + ) + self.conv6 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)) + ) + self.res5 = DilateResBlock(ngf*4, [1,1]) + self.res6 = DilateResBlock(ngf*4, [1,1]) + + self.LE_256_Q = Query(ngf, ngf // self.key_scale) + self.RE_256_Q = Query(ngf, ngf // self.key_scale) + self.MO_256_Q = Query(ngf, ngf // self.key_scale) + self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + + + def forward(self, img, locs): + le_location = locs[:,0,:].int().cpu().numpy() + re_location = locs[:,1,:].int().cpu().numpy() + no_location = locs[:,2,:].int().cpu().numpy() + mo_location = locs[:,3,:].int().cpu().numpy() + + + f1_0 = self.conv1(img) + f1_1 = self.res1(f1_0) + f2_0 = self.conv2(f1_1) + f2_1 = self.res2(f2_0) + + f3_0 = self.conv3(f2_1) + f3_1 = self.res3(f3_0) + f4_0 = self.conv4(f3_1) + f4_1 = self.res4(f4_0) + + f5_0 = self.conv5(f4_1) + f5_1 = self.res5(f5_0) + f6_0 = self.conv6(f5_1) + f6_1 = self.res6(f6_0) + + + ####ROI Align + le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2) + re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2) + mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2) + + le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4) + re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4) + mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4) + + le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8) + re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8) + mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8) + + + le_256_q = self.LE_256_Q(le_part_256) + re_256_q = self.RE_256_Q(re_part_256) + mo_256_q = self.MO_256_Q(mo_part_256) + + le_128_q = self.LE_128_Q(le_part_128) + re_128_q = self.RE_128_Q(re_part_128) + mo_128_q = self.MO_128_Q(mo_part_128) + + le_64_q = self.LE_64_Q(le_part_64) + re_64_q = self.RE_64_Q(re_part_64) + mo_64_q = self.MO_64_Q(mo_part_64) + + return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\ + 'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \ + 'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \ + 'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \ + 'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\ + 'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\ + 'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q} + + +class DMDNet(nn.Module): + def __init__(self, ngf = 64, banks_num = 128): + super().__init__() + self.part_sizes = np.array([80,80,50,110]) # size for 512 + self.feature_sizes = np.array([256,128,64]) # size for 512 + + self.banks_num = banks_num + self.key_scale = 4 + + self.E_lq = FeatureExtractor(key_scale = self.key_scale) + self.E_hq = FeatureExtractor(key_scale = self.key_scale) + + self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + + self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2) + + self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4) + + + self.LE_256_Attention = AttentionBlock(64) + self.RE_256_Attention = AttentionBlock(64) + self.MO_256_Attention = AttentionBlock(64) + + self.LE_128_Attention = AttentionBlock(128) + self.RE_128_Attention = AttentionBlock(128) + self.MO_128_Attention = AttentionBlock(128) + + self.LE_64_Attention = AttentionBlock(256) + self.RE_64_Attention = AttentionBlock(256) + self.MO_64_Attention = AttentionBlock(256) + + self.LE_256_Mask = MaskAttention(64) + self.RE_256_Mask = MaskAttention(64) + self.MO_256_Mask = MaskAttention(64) + + self.LE_128_Mask = MaskAttention(128) + self.RE_128_Mask = MaskAttention(128) + self.MO_128_Mask = MaskAttention(128) + + self.LE_64_Mask = MaskAttention(256) + self.RE_64_Mask = MaskAttention(256) + self.MO_64_Mask = MaskAttention(256) + + self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1]) + + self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) # + self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) # + self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) # + self.up4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + UpResBlock(ngf), + UpResBlock(ngf), + SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)), + nn.Tanh() + ) + + # define generic memory, revise register_buffer to register_parameter for backward update + self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40)) + self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40)) + self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55)) + self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40)) + self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40)) + self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55)) + + + self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20)) + self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20)) + self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27)) + self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20)) + self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20)) + self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27)) + + self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10)) + self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10)) + self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13)) + self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10)) + self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10)) + self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13)) + + + def readMem(self, k, v, q): + sim = F.conv2d(q, k) + score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128 + sb,sn,sw,sh = score.size() + s_m = score.view(sb, -1).unsqueeze(1)#2*1*M + vb,vn,vw,vh = v.size() + v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h) + mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh) + max_inds = torch.argmax(score, dim=1).squeeze() + return mem_out, max_inds + + + def memorize(self, img, locs): + fs = self.E_hq(img, locs) + LE256_key, LE256_value = self.LE_256_KV(fs['le256']) + RE256_key, RE256_value = self.RE_256_KV(fs['re256']) + MO256_key, MO256_value = self.MO_256_KV(fs['mo256']) + + LE128_key, LE128_value = self.LE_128_KV(fs['le128']) + RE128_key, RE128_value = self.RE_128_KV(fs['re128']) + MO128_key, MO128_value = self.MO_128_KV(fs['mo128']) + + LE64_key, LE64_value = self.LE_64_KV(fs['le64']) + RE64_key, RE64_value = self.RE_64_KV(fs['re64']) + MO64_key, MO64_value = self.MO_64_KV(fs['mo64']) + + Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value} + Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value} + Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value} + + FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']} + FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']} + FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']} + + return Mem256, Mem128, Mem64 + + def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None): + le_256_q = fs_in['le_256_q'] + re_256_q = fs_in['re_256_q'] + mo_256_q = fs_in['mo_256_q'] + + le_128_q = fs_in['le_128_q'] + re_128_q = fs_in['re_128_q'] + mo_128_q = fs_in['mo_128_q'] + + le_64_q = fs_in['le_64_q'] + re_64_q = fs_in['re_64_q'] + mo_64_q = fs_in['mo_64_q'] + + + ####for 256 + le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q) + re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q) + mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q) + + le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q) + re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q) + mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q) + + le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q) + re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q) + mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q) + + if sp_256 is not None and sp_128 is not None and sp_64 is not None: + le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q) + re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q) + mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q) + le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g) + le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g + re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g) + re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g + mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g) + mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g + + le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q) + re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q) + mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q) + le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g) + le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g + re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g) + re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g + mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g) + mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g + + le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q) + re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q) + mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q) + le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g) + le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g + re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g) + re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g + mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g) + mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g + else: + le_256_mem = le_256_mem_g + re_256_mem = re_256_mem_g + mo_256_mem = mo_256_mem_g + le_128_mem = le_128_mem_g + re_128_mem = re_128_mem_g + mo_128_mem = mo_128_mem_g + le_64_mem = le_64_mem_g + re_64_mem = re_64_mem_g + mo_64_mem = mo_64_mem_g + + le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256']) + re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256']) + mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256']) + + ####for 128 + le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128']) + re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128']) + mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128']) + + ####for 64 + le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64']) + re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64']) + mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64']) + + + EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm} + EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm} + EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm} + Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds} + Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds} + Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds} + return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64 + + def reconstruct(self, fs_in, locs, memstar): + le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm'] + le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm'] + le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm'] + + le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256'] + re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256'] + mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256'] + + le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128'] + re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128'] + mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128'] + + le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64'] + re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64'] + mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64'] + + + le_location = locs[:,0,:] + re_location = locs[:,1,:] + mo_location = locs[:,3,:] + + # Somehow with latest Torch it doesn't like numpy wrappers anymore + + # le_location = le_location.cpu().int().numpy() + # re_location = re_location.cpu().int().numpy() + # mo_location = mo_location.cpu().int().numpy() + le_location = le_location.cpu().int() + re_location = re_location.cpu().int() + mo_location = mo_location.cpu().int() + + up_in_256 = fs_in['f256'].clone()# * 0 + up_in_128 = fs_in['f128'].clone()# * 0 + up_in_64 = fs_in['f64'].clone()# * 0 + + for i in range(fs_in['f256'].size(0)): + up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False) + up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False) + up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False) + + up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False) + up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False) + up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False) + + up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False) + up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False) + up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False) + + ms_in_64 = self.MSDilate(fs_in['f64'].clone()) + fea_up1 = self.up1(ms_in_64, up_in_64) + fea_up2 = self.up2(fea_up1, up_in_128) # + fea_up3 = self.up3(fea_up2, up_in_256) # + output = self.up4(fea_up3) # + return output + + def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None): + return self.memorize(sp_imgs, sp_locs) + + def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None): + try: + fs_in = self.E_lq(lq, loc) # low quality images + except Exception as e: + print(e) + + GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in) + GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64]) + if sp_256 is not None and sp_128 is not None and sp_64 is not None: + GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64) + GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64]) + else: + GSOut = None + return GeOut, GSOut + +class UpResBlock(nn.Module): + def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d): + super(UpResBlock, self).__init__() + self.Model = nn.Sequential( + SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), + ) + def forward(self, x): + out = x + self.Model(x) + return out diff --git a/roop/processors/Enhance_GFPGAN.py b/roop/processors/Enhance_GFPGAN.py new file mode 100644 index 0000000000000000000000000000000000000000..ca61cb70f302712ca1c6f54ee06aad9ed0f33f0c --- /dev/null +++ b/roop/processors/Enhance_GFPGAN.py @@ -0,0 +1,77 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +# THREAD_LOCK = threading.Lock() + + +class Enhance_GFPGAN(): + plugin_options:dict = None + + model_gfpgan = None + name = None + devicename = None + + processorname = 'gfpgan' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_gfpgan is None: + model_path = resolve_relative_path('../models/GFPGANv1.4.onnx') + self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + self.name = self.model_gfpgan.get_inputs()[0].name + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + io_binding = self.model_gfpgan.io_binding() + io_binding.bind_cpu_input("input", temp_frame) + io_binding.bind_output("1288", self.devicename) + self.model_gfpgan.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + + # post-process + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gfpgan = None + + + + + + + + + + + diff --git a/roop/processors/Enhance_GPEN.py b/roop/processors/Enhance_GPEN.py new file mode 100644 index 0000000000000000000000000000000000000000..9821e70534e3bddcd2a932548fd7b9250d85a41a --- /dev/null +++ b/roop/processors/Enhance_GPEN.py @@ -0,0 +1,63 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +class Enhance_GPEN(): + plugin_options:dict = None + + model_gpen = None + name = None + devicename = None + + processorname = 'gpen' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_gpen is None: + model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx') + self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + self.name = self.model_gpen.get_inputs()[0].name + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + io_binding = self.model_gpen.io_binding() + io_binding.bind_cpu_input("input", temp_frame) + io_binding.bind_output("output", self.devicename) + self.model_gpen.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + + # post-process + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + self.model_gpen = None diff --git a/roop/processors/Enhance_RestoreFormerPPlus.py b/roop/processors/Enhance_RestoreFormerPPlus.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d71034573cf1e63be77a4b9acafc854f189536 --- /dev/null +++ b/roop/processors/Enhance_RestoreFormerPPlus.py @@ -0,0 +1,64 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + +class Enhance_RestoreFormerPPlus(): + plugin_options:dict = None + model_restoreformerpplus = None + devicename = None + name = None + + processorname = 'restoreformer++' + type = 'enhance' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_restoreformerpplus is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx') + self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_restoreformerpplus.get_inputs() + model_outputs = self.model_restoreformerpplus.get_outputs() + self.io_binding = self.model_restoreformerpplus.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame: + # preprocess + input_size = temp_frame.shape[1] + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) # .astype(np.float32) + self.model_restoreformerpplus.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + + result = np.clip(result, -1, 1) + result = (result + 1) / 2 + result = result.transpose(1, 2, 0) * 255.0 + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + + def Release(self): + del self.model_restoreformerpplus + self.model_restoreformerpplus = None + del self.io_binding + self.io_binding = None + diff --git a/roop/processors/FaceSwapInsightFace.py b/roop/processors/FaceSwapInsightFace.py new file mode 100644 index 0000000000000000000000000000000000000000..34290899fed8f74b4e7bc7aaf2909779dfb4d639 --- /dev/null +++ b/roop/processors/FaceSwapInsightFace.py @@ -0,0 +1,69 @@ +import roop.globals +import cv2 +import numpy as np +import onnx +import onnxruntime + +from roop.typing import Face, Frame +from roop.utilities import resolve_relative_path + + + +class FaceSwapInsightFace(): + plugin_options:dict = None + model_swap_insightface = None + + processorname = 'faceswap' + type = 'swap' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_swap_insightface is None: + model_path = resolve_relative_path('../models/inswapper_128.onnx') + graph = onnx.load(model_path).graph + self.emap = onnx.numpy_helper.to_array(graph.initializer[-1]) + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + self.input_mean = 0.0 + self.input_std = 255.0 + #cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'} + sess_options = onnxruntime.SessionOptions() + sess_options.enable_cpu_mem_arena = False + self.model_swap_insightface = onnxruntime.InferenceSession(model_path, sess_options, providers=roop.globals.execution_providers) + + + + def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: + blob = cv2.dnn.blobFromImage(temp_frame, 1.0 / self.input_std, (128, 128), + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + latent = source_face.normed_embedding.reshape((1,-1)) + latent = np.dot(latent, self.emap) + latent /= np.linalg.norm(latent) + io_binding = self.model_swap_insightface.io_binding() + io_binding.bind_cpu_input("target", blob) + io_binding.bind_cpu_input("source", latent) + io_binding.bind_output("output", self.devicename) + self.model_swap_insightface.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu()[0] + img_fake = ort_outs.transpose((0,2,3,1))[0] + return np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] + + + img_fake, M = self.model_swap_insightface.get(temp_frame, target_face, source_face, paste_back=False) + # target_face.matrix = M + # return img_fake + + + def Release(self): + del self.model_swap_insightface + self.model_swap_insightface = None + + + + + + diff --git a/roop/processors/Frame_Colorizer.py b/roop/processors/Frame_Colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..372f81870b6c47f543707e8eefff3a474532b493 --- /dev/null +++ b/roop/processors/Frame_Colorizer.py @@ -0,0 +1,70 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Colorizer(): + plugin_options:dict = None + model_colorizer = None + devicename = None + prev_type = None + + processorname = 'deoldify' + type = 'frame_colorizer' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]: + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_colorizer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + if self.prev_type == "deoldify_artistic": + model_path = resolve_relative_path('../models/Frame/deoldify_artistic.onnx') + elif self.prev_type == "deoldify_stable": + model_path = resolve_relative_path('../models/Frame/deoldify_stable.onnx') + + onnxruntime.set_default_logger_severity(3) + self.model_colorizer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_colorizer.get_inputs() + model_outputs = self.model_colorizer.get_outputs() + self.io_binding = self.model_colorizer.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, input_frame: Frame) -> Frame: + temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) + temp_frame = cv2.resize(temp_frame, (256, 256)) + temp_frame = temp_frame.transpose((2, 0, 1)) + temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32) + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) + self.model_colorizer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + colorized_frame = result.transpose(1, 2, 0) + colorized_frame = cv2.resize(colorized_frame, (input_frame.shape[1], input_frame.shape[0])) + temp_blue_channel, _, _ = cv2.split(input_frame) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(np.uint8) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB) + _, color_green_channel, color_red_channel = cv2.split(colorized_frame) + colorized_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel)) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR) + return colorized_frame.astype(np.uint8) + + + def Release(self): + del self.model_colorizer + self.model_colorizer = None + del self.io_binding + self.io_binding = None + diff --git a/roop/processors/Frame_Filter.py b/roop/processors/Frame_Filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b1405c329167a4e7f4f926ade5cf06ab6166466f --- /dev/null +++ b/roop/processors/Frame_Filter.py @@ -0,0 +1,105 @@ +import cv2 +import numpy as np + +from roop.typing import Frame + +class Frame_Filter(): + processorname = 'generic_filter' + type = 'frame_processor' + + plugin_options:dict = None + + c64_palette = np.array([ + [0, 0, 0], + [255, 255, 255], + [0x81, 0x33, 0x38], + [0x75, 0xce, 0xc8], + [0x8e, 0x3c, 0x97], + [0x56, 0xac, 0x4d], + [0x2e, 0x2c, 0x9b], + [0xed, 0xf1, 0x71], + [0x8e, 0x50, 0x29], + [0x55, 0x38, 0x00], + [0xc4, 0x6c, 0x71], + [0x4a, 0x4a, 0x4a], + [0x7b, 0x7b, 0x7b], + [0xa9, 0xff, 0x9f], + [0x70, 0x6d, 0xeb], + [0xb2, 0xb2, 0xb2] + ]) + + + def RenderC64Screen(self, image): + # Simply round the color values to the nearest color in the palette + image = cv2.resize(image,(320,200)) + palette = self.c64_palette / 255.0 # Normalize palette + img_normalized = image / 255.0 # Normalize image + + # Calculate the index in the palette that is closest to each pixel in the image + indices = np.sqrt(((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3)).argmin(axis=2) + # Map the image to the palette colors + mapped_image = palette[indices] + return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image + + + def RenderDetailEnhance(self, image): + return cv2.detailEnhance(image) + + def RenderStylize(self, image): + return cv2.stylization(image) + + def RenderPencilSketch(self, image): + imgray, imout = cv2.pencilSketch(image, sigma_s=60, sigma_r=0.07, shade_factor=0.05) + return imout + + def RenderCartoon(self, image): + numDownSamples = 2 # number of downscaling steps + numBilateralFilters = 7 # number of bilateral filtering steps + + img_color = image + for _ in range(numDownSamples): + img_color = cv2.pyrDown(img_color) + for _ in range(numBilateralFilters): + img_color = cv2.bilateralFilter(img_color, 9, 9, 7) + for _ in range(numDownSamples): + img_color = cv2.pyrUp(img_color) + img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + img_blur = cv2.medianBlur(img_gray, 7) + img_edge = cv2.adaptiveThreshold(img_blur, 255, + cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2) + img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB) + if img_color.shape != image.shape: + img_color = cv2.resize(img_color, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR) + if img_color.shape != img_edge.shape: + img_edge = cv2.resize(img_edge, (img_color.shape[1], img_color.shape[0]), interpolation=cv2.INTER_LINEAR) + return cv2.bitwise_and(img_color, img_edge) + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + self.plugin_options = plugin_options + + def Run(self, temp_frame: Frame) -> Frame: + subtype = self.plugin_options["subtype"] + if subtype == "stylize": + return self.RenderStylize(temp_frame).astype(np.uint8) + if subtype == "detailenhance": + return self.RenderDetailEnhance(temp_frame).astype(np.uint8) + if subtype == "pencil": + return self.RenderPencilSketch(temp_frame).astype(np.uint8) + if subtype == "cartoon": + return self.RenderCartoon(temp_frame).astype(np.uint8) + if subtype == "C64": + return self.RenderC64Screen(temp_frame).astype(np.uint8) + + + def Release(self): + pass + + def getProcessedResolution(self, width, height): + if self.plugin_options["subtype"] == "C64": + return (320,200) + return None + diff --git a/roop/processors/Frame_Masking.py b/roop/processors/Frame_Masking.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4e77fec51854fc67c5274193665fd3555c24bb --- /dev/null +++ b/roop/processors/Frame_Masking.py @@ -0,0 +1,71 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Masking(): + plugin_options:dict = None + model_masking = None + devicename = None + name = None + + processorname = 'removebg' + type = 'frame_masking' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_masking is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"] + self.devicename = self.devicename.replace('mps', 'cpu') + model_path = resolve_relative_path('../models/Frame/isnet-general-use.onnx') + self.model_masking = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_masking.get_inputs() + model_outputs = self.model_masking.get_outputs() + self.io_binding = self.model_masking.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, temp_frame: Frame) -> Frame: + # Pre process:Resize, BGR->RGB, float32 cast + input_image = cv2.resize(temp_frame, (1024, 1024)) + input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) + mean = [0.5, 0.5, 0.5] + std = [1.0, 1.0, 1.0] + input_image = (input_image / 255.0 - mean) / std + input_image = input_image.transpose(2, 0, 1) + input_image = np.expand_dims(input_image, axis=0) + input_image = input_image.astype('float32') + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image) + self.model_masking.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + # Post process:squeeze, Sigmoid, Normarize, uint8 cast + mask = np.squeeze(result[0]) + min_value = np.min(mask) + max_value = np.max(mask) + mask = (mask - min_value) / (max_value - min_value) + #mask = np.where(mask < score_th, 0, 1) + #mask *= 255 + mask = cv2.resize(mask, (temp_frame.shape[1], temp_frame.shape[0]), interpolation=cv2.INTER_LINEAR) + mask = np.reshape(mask, [mask.shape[0],mask.shape[1],1]) + result = mask * temp_frame.astype(np.float32) + return result.astype(np.uint8) + + + + def Release(self): + del self.model_masking + self.model_masking = None + del self.io_binding + self.io_binding = None + diff --git a/roop/processors/Frame_Upscale.py b/roop/processors/Frame_Upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..e323e98eee7cea6662a6426eb12ebc6a8b753974 --- /dev/null +++ b/roop/processors/Frame_Upscale.py @@ -0,0 +1,131 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals +import threading + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + +class Frame_Upscale(): + plugin_options:dict = None + model_upscale = None + devicename = None + prev_type = None + + processorname = 'upscale' + type = 'frame_enhancer' + + THREAD_LOCK_UPSCALE = threading.Lock() + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]: + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_upscale is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + if self.prev_type == "esrganx4": + model_path = resolve_relative_path('../models/Frame/real_esrgan_x4.onnx') + self.scale = 4 + elif self.prev_type == "esrganx2": + model_path = resolve_relative_path('../models/Frame/real_esrgan_x2.onnx') + self.scale = 2 + elif self.prev_type == "lsdirx4": + model_path = resolve_relative_path('../models/Frame/lsdir_x4.onnx') + self.scale = 4 + + self.model_upscale = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_upscale.get_inputs() + model_outputs = self.model_upscale.get_outputs() + self.io_binding = self.model_upscale.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def getProcessedResolution(self, width, height): + return (width * self.scale, height * self.scale) + +# borrowed from facefusion -> https://github.com/facefusion/facefusion + def prepare_tile_frame(self, tile_frame : Frame) -> Frame: + tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis = 0) + tile_frame = tile_frame.transpose(0, 3, 1, 2) + tile_frame = tile_frame.astype(np.float32) / 255 + return tile_frame + + + def normalize_tile_frame(self, tile_frame : Frame) -> Frame: + tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255 + tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1] + return tile_frame + + def create_tile_frames(self, input_frame : Frame, size): + input_frame = np.pad(input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) + tile_width = size[0] - 2 * size[2] + pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width + pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width + pad_vision_frame = np.pad(input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0))) + pad_height, pad_width = pad_vision_frame.shape[:2] + row_range = range(size[2], pad_height - size[2], tile_width) + col_range = range(size[2], pad_width - size[2], tile_width) + tile_frames = [] + + for row_frame in row_range: + top = row_frame - size[2] + bottom = row_frame + size[2] + tile_width + for column_vision_frame in col_range: + left = column_vision_frame - size[2] + right = column_vision_frame + size[2] + tile_width + tile_frames.append(pad_vision_frame[top:bottom, left:right, :]) + return tile_frames, pad_width, pad_height + + + def merge_tile_frames(self, tile_frames, temp_width : int, temp_height : int, pad_width : int, pad_height : int, size) -> Frame: + merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8) + tile_width = tile_frames[0].shape[1] - 2 * size[2] + tiles_per_row = min(pad_width // tile_width, len(tile_frames)) + + for index, tile_frame in enumerate(tile_frames): + tile_frame = tile_frame[size[2]:-size[2], size[2]:-size[2]] + row_index = index // tiles_per_row + col_index = index % tiles_per_row + top = row_index * tile_frame.shape[0] + bottom = top + tile_frame.shape[0] + left = col_index * tile_frame.shape[1] + right = left + tile_frame.shape[1] + merge_frame[top:bottom, left:right, :] = tile_frame + merge_frame = merge_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :] + return merge_frame + + + def Run(self, temp_frame: Frame) -> Frame: + size = (128, 8, 2) + temp_height, temp_width = temp_frame.shape[:2] + upscale_tile_frames, pad_width, pad_height = self.create_tile_frames(temp_frame, size) + + for index, tile_frame in enumerate(upscale_tile_frames): + tile_frame = self.prepare_tile_frame(tile_frame) + with self.THREAD_LOCK_UPSCALE: + self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame) + self.model_upscale.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0] + upscale_tile_frames[index] = self.normalize_tile_frame(result) + final_frame = self.merge_tile_frames(upscale_tile_frames, temp_width * self.scale + , temp_height * self.scale + , pad_width * self.scale, pad_height * self.scale + , (size[0] * self.scale, size[1] * self.scale, size[2] * self.scale)) + return final_frame.astype(np.uint8) + + + + def Release(self): + del self.model_upscale + self.model_upscale = None + del self.io_binding + self.io_binding = None + diff --git a/roop/processors/Mask_Clip2Seg.py b/roop/processors/Mask_Clip2Seg.py new file mode 100644 index 0000000000000000000000000000000000000000..5df3b3e37ea10eb2440828a08e129d8c62f98086 --- /dev/null +++ b/roop/processors/Mask_Clip2Seg.py @@ -0,0 +1,94 @@ +import cv2 +import numpy as np +import torch +import threading +from torchvision import transforms +from clip.clipseg import CLIPDensePredT +import numpy as np + +from roop.typing import Frame + +THREAD_LOCK_CLIP = threading.Lock() + + +class Mask_Clip2Seg(): + plugin_options:dict = None + model_clip = None + + processorname = 'clip2seg' + type = 'mask' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_clip is None: + self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) + self.model_clip.eval(); + self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False) + + device = torch.device(self.plugin_options["devicename"]) + self.model_clip.to(device) + + + def Run(self, img1, keywords:str) -> Frame: + if keywords is None or len(keywords) < 1 or img1 is None: + return img1 + + source_image_small = cv2.resize(img1, (256,256)) + + img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32) + mask_border = 1 + l = 0 + t = 0 + r = 1 + b = 1 + + mask_blur = 5 + clip_blur = 5 + + img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)), + (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1) + img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0) + img_mask /= 255 + + + input_image = source_image_small + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + transforms.Resize((256, 256)), + ]) + img = transform(input_image).unsqueeze(0) + + thresh = 0.5 + prompts = keywords.split(',') + with THREAD_LOCK_CLIP: + with torch.no_grad(): + preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0] + clip_mask = torch.sigmoid(preds[0][0]) + for i in range(len(prompts)-1): + clip_mask += torch.sigmoid(preds[i+1][0]) + + clip_mask = clip_mask.data.cpu().numpy() + np.clip(clip_mask, 0, 1) + + clip_mask[clip_mask>thresh] = 1.0 + clip_mask[clip_mask<=thresh] = 0.0 + kernel = np.ones((5, 5), np.float32) + clip_mask = cv2.dilate(clip_mask, kernel, iterations=1) + clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0) + + img_mask *= clip_mask + img_mask[img_mask<0.0] = 0.0 + return img_mask + + + + def Release(self): + self.model_clip = None + diff --git a/roop/processors/Mask_XSeg.py b/roop/processors/Mask_XSeg.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8e87741c9aa99cde84aa20566bb8c3db548fe2 --- /dev/null +++ b/roop/processors/Mask_XSeg.py @@ -0,0 +1,60 @@ +import numpy as np +import cv2 +import onnxruntime +import threading +import roop.globals + +from roop.typing import Frame +from roop.utilities import resolve_relative_path + +THREAD_LOCK_CLIP = threading.Lock() + + +class Mask_XSeg(): + plugin_options:dict = None + + model_xseg = None + + processorname = 'mask_xseg' + type = 'mask' + + + def Initialize(self, plugin_options:dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_xseg is None: + model_path = resolve_relative_path('../models/xseg.onnx') + onnxruntime.set_default_logger_severity(3) + self.model_xseg = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers) + self.model_inputs = self.model_xseg.get_inputs() + self.model_outputs = self.model_xseg.get_outputs() + + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu') + + + def Run(self, img1, keywords:str) -> Frame: + temp_frame = cv2.resize(img1, (256, 256), cv2.INTER_CUBIC) + temp_frame = temp_frame.astype('float32') / 255.0 + temp_frame = temp_frame[None, ...] + io_binding = self.model_xseg.io_binding() + io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) + io_binding.bind_output(self.model_outputs[0].name, self.devicename) + self.model_xseg.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + result = np.clip(result, 0, 1.0) + result[result < 0.1] = 0 + # invert values to mask areas to keep + result = 1.0 - result + return result + + + def Release(self): + del self.model_xseg + self.model_xseg = None + + diff --git a/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc b/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08b2a9fd8cf366f9f56c0612b91034e1b0025631 Binary files /dev/null and b/roop/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc differ diff --git a/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc b/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e23db6f39b0eea20125f3a99ff8aba1116df79 Binary files /dev/null and b/roop/processors/__pycache__/Mask_XSeg.cpython-310.pyc differ diff --git a/roop/processors/__pycache__/__init__.cpython-310.pyc b/roop/processors/__pycache__/__init__.cpython-310.pyc index 63d8eb4aa92bc0bdd8ba9fdeb374697a9c663b1a..b477fce10f565769459daa9c46b19ba286c59529 100644 Binary files a/roop/processors/__pycache__/__init__.cpython-310.pyc and b/roop/processors/__pycache__/__init__.cpython-310.pyc differ diff --git a/roop/template_parser.py b/roop/template_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..a51113b69830119fc84fd15c2a428321ac1d8010 --- /dev/null +++ b/roop/template_parser.py @@ -0,0 +1,23 @@ +import re +from datetime import datetime + +template_functions = { + "timestamp": lambda data: str(int(datetime.now().timestamp())), + "i": lambda data: data.get("index", False), + "file": lambda data: data.get("file", False), + "date": lambda data: datetime.now().strftime("%Y-%m-%d"), + "time": lambda data: datetime.now().strftime("%H-%M-%S"), +} + + +def parse(text: str, data: dict): + pattern = r"\{([^}]+)\}" + + matches = re.findall(pattern, text) + + for match in matches: + replacement = template_functions[match](data) + if replacement is not False: + text = text.replace(f"{{{match}}}", replacement) + + return text diff --git a/roop/typing.py b/roop/typing.py index 1cff7440616e20bfe7b8bc287f86d11bf1b0f083..263f1b5b0331332dfab9f682438b364c612cfdf8 100644 --- a/roop/typing.py +++ b/roop/typing.py @@ -1,7 +1,9 @@ from typing import Any from insightface.app.common import Face +from roop.FaceSet import FaceSet import numpy Face = Face +FaceSet = FaceSet Frame = numpy.ndarray[Any, Any] diff --git a/roop/util_ffmpeg.py b/roop/util_ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..8b8c9a978f2acfd647c5e1088c0264e9193b68be --- /dev/null +++ b/roop/util_ffmpeg.py @@ -0,0 +1,112 @@ + +import os +import subprocess +import roop.globals +import roop.utilities as util + +from typing import List, Any + +def run_ffmpeg(args: List[str]) -> bool: + commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-y', '-loglevel', roop.globals.log_level] + commands.extend(args) + print("Running ffmpeg") + try: + subprocess.check_output(commands, stderr=subprocess.STDOUT) + return True + except Exception as e: + print("Running ffmpeg failed! Commandline:") + print(" ".join(map(str, commands))) # Ensure all elements are strings + print(e) + return False + + + +def cut_video(original_video: str, cut_video: str, start_frame: int, end_frame: int, reencode: bool): + fps = util.detect_fps(original_video) + start_time = start_frame / fps + num_frames = end_frame - start_frame + + if reencode: + run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-c:v', roop.globals.video_encoder, '-c:a', 'aac', '-frames:v', str(num_frames), cut_video]) + else: + run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-frames:v', str(num_frames), '-c:v' ,'copy','-c:a' ,'copy', cut_video]) + +def join_videos(videos: List[str], dest_filename: str, simple: bool): + if simple: + txtfilename = util.resolve_relative_path('../temp') + txtfilename = os.path.join(txtfilename, 'joinvids.txt') + with open(txtfilename, "w", encoding="utf-8") as f: + for v in videos: + v = v.replace('\\', '/') + f.write(f"file {v}\n") + commands = ['-f', 'concat', '-safe', '0', '-i', f'{txtfilename}', '-vcodec', 'copy', f'{dest_filename}'] + run_ffmpeg(commands) + + else: + inputs = [] + filter = '' + for i,v in enumerate(videos): + inputs.append('-i') + inputs.append(v) + filter += f'[{i}:v:0][{i}:a:0]' + run_ffmpeg([" ".join(inputs), '-filter_complex', f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', '-map', '"[outv]"', '-map', '"[outa]"', dest_filename]) + + # filter += f'[{i}:v:0][{i}:a:0]' + # run_ffmpeg([" ".join(inputs), '-filter_complex', f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', '-map', '"[outv]"', '-map', '"[outa]"', dest_filename]) + + + +def extract_frames(target_path : str, trim_frame_start, trim_frame_end, fps : float) -> bool: + util.create_temp(target_path) + temp_directory_path = util.get_temp_directory_path(target_path) + commands = ['-i', target_path, '-q:v', '1', '-pix_fmt', 'rgb24', ] + if trim_frame_start is not None and trim_frame_end is not None: + commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',fps=' + str(fps) ]) + commands.extend(['-vsync', '0', os.path.join(temp_directory_path, '%06d.' + roop.globals.CFG.output_image_format)]) + return run_ffmpeg(commands) + + +def create_video(target_path: str, dest_filename: str, fps: float = 24.0, temp_directory_path: str = None) -> None: + if temp_directory_path is None: + temp_directory_path = util.get_temp_directory_path(target_path) + print("dest file name is " + dest_filename) + run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, f'%06d.{roop.globals.CFG.output_image_format}'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', dest_filename]) + return dest_filename + + +def create_gif_from_video(video_path: str, gif_path): + from roop.capturer import get_video_frame + + fps = util.detect_fps(video_path) + frame = get_video_frame(video_path) + + run_ffmpeg(['-i', video_path, '-vf', f'fps={fps},scale={frame.shape[0]}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse', '-loop', '0', gif_path]) + + +def restore_audio(intermediate_video: str, original_video: str, trim_frame_start, trim_frame_end, final_video : str) -> None: + fps = util.detect_fps(original_video) + commands = [ '-i', intermediate_video ] + if trim_frame_start is None and trim_frame_end is None: + commands.extend([ '-c:a', 'copy' ]) + else: + # if trim_frame_start is not None: + # start_time = trim_frame_start / fps + # commands.extend([ '-ss', format(start_time, ".2f")]) + # else: + # commands.extend([ '-ss', '0' ]) + # if trim_frame_end is not None: + # end_time = trim_frame_end / fps + # commands.extend([ '-to', format(end_time, ".2f")]) + # commands.extend([ '-c:a', 'aac' ]) + if trim_frame_start is not None: + start_time = trim_frame_start / fps + commands.extend([ '-ss', format(start_time, ".2f")]) + else: + commands.extend([ '-ss', '0' ]) + if trim_frame_end is not None: + end_time = trim_frame_end / fps + commands.extend([ '-to', format(end_time, ".2f")]) + commands.extend([ '-i', original_video, "-c", "copy" ]) + + commands.extend([ '-map', '0:v:0', '-map', '1:a:0?', '-shortest', final_video ]) + run_ffmpeg(commands) diff --git a/roop/utilities.py b/roop/utilities.py index 90c8d981f5f159a459ca0c08cc23dfac8d04c068..3a7a41c079e0d71e81c9d3326100a598471c79a0 100644 --- a/roop/utilities.py +++ b/roop/utilities.py @@ -5,64 +5,90 @@ import platform import shutil import ssl import subprocess +import sys import urllib +import torch +import gradio +import tempfile +import cv2 +import zipfile +import traceback + from pathlib import Path from typing import List, Any from tqdm import tqdm +from scipy.spatial import distance + +import roop.template_parser as template_parser import roop.globals -TEMP_FILE = 'temp.mp4' -TEMP_DIRECTORY = 'temp' +TEMP_FILE = "temp.mp4" +TEMP_DIRECTORY = "temp" # monkey patch ssl for mac -if platform.system().lower() == 'darwin': +if platform.system().lower() == "darwin": ssl._create_default_https_context = ssl._create_unverified_context -def run_ffmpeg(args: List[str]) -> bool: - commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-loglevel', roop.globals.log_level] - commands.extend(args) - try: - subprocess.check_output(commands, stderr=subprocess.STDOUT) - return True - except Exception: - pass - return False - - +# https://github.com/facefusion/facefusion/blob/master/facefusion def detect_fps(target_path: str) -> float: - command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=r_frame_rate', '-of', 'default=noprint_wrappers=1:nokey=1', target_path] - output = subprocess.check_output(command).decode().strip().split('/') - try: - numerator, denominator = map(int, output) - return numerator / denominator - except Exception: - pass - return 30.0 - - -def extract_frames(target_path: str) -> None: - temp_directory_path = get_temp_directory_path(target_path) - run_ffmpeg(['-i', target_path, '-pix_fmt', 'rgb24', os.path.join(temp_directory_path, '%04d.png')]) - - -def create_video(target_path: str, fps: float = 30.0) -> None: - temp_output_path = get_temp_output_path(target_path) - temp_directory_path = get_temp_directory_path(target_path) - run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, '%04d.png'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', temp_output_path]) - - -def restore_audio(target_path: str, output_path: str) -> None: - temp_output_path = get_temp_output_path(target_path) - done = run_ffmpeg(['-i', temp_output_path, '-i', target_path, '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0', '-y', output_path]) - if not done: - move_temp(target_path, output_path) + fps = 24.0 + cap = cv2.VideoCapture(target_path) + if cap.isOpened(): + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + return fps + + +# Gradio wants Images in RGB +def convert_to_gradio(image): + if image is None: + return None + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +def sort_filenames_ignore_path(filenames): + """Sorts a list of filenames containing a complete path by their filename, + while retaining their original path. + + Args: + filenames: A list of filenames containing a complete path. + + Returns: + A sorted list of filenames containing a complete path. + """ + filename_path_tuples = [ + (os.path.split(filename)[1], filename) for filename in filenames + ] + sorted_filename_path_tuples = sorted(filename_path_tuples, key=lambda x: x[0]) + return [ + filename_path_tuple[1] for filename_path_tuple in sorted_filename_path_tuples + ] + + +def sort_rename_frames(path: str): + filenames = os.listdir(path) + filenames.sort() + for i in range(len(filenames)): + of = os.path.join(path, filenames[i]) + newidx = i + 1 + new_filename = os.path.join( + path, f"{newidx:06d}." + roop.globals.CFG.output_image_format + ) + os.rename(of, new_filename) def get_temp_frame_paths(target_path: str) -> List[str]: temp_directory_path = get_temp_directory_path(target_path) - return glob.glob((os.path.join(glob.escape(temp_directory_path), '*.png'))) + return glob.glob( + ( + os.path.join( + glob.escape(temp_directory_path), + f"*.{roop.globals.CFG.output_image_format}", + ) + ) + ) def get_temp_directory_path(target_path: str) -> str: @@ -81,10 +107,35 @@ def normalize_output_path(source_path: str, target_path: str, output_path: str) source_name, _ = os.path.splitext(os.path.basename(source_path)) target_name, target_extension = os.path.splitext(os.path.basename(target_path)) if os.path.isdir(output_path): - return os.path.join(output_path, source_name + '-' + target_name + target_extension) + return os.path.join( + output_path, source_name + "-" + target_name + target_extension + ) return output_path +def get_destfilename_from_path( + srcfilepath: str, destfilepath: str, extension: str +) -> str: + fn, ext = os.path.splitext(os.path.basename(srcfilepath)) + if "." in extension: + return os.path.join(destfilepath, f"{fn}{extension}") + return os.path.join(destfilepath, f"{fn}{extension}{ext}") + + +def replace_template(file_path: str, index: int = 0) -> str: + fn, ext = os.path.splitext(os.path.basename(file_path)) + + # Remove the "__temp" placeholder that was used as a temporary filename + fn = fn.replace("__temp", "") + + template = roop.globals.CFG.output_template + replaced_filename = template_parser.parse( + template, {"index": str(index), "file": fn} + ) + + return os.path.join(roop.globals.output_path, f"{replaced_filename}{ext}") + + def create_temp(target_path: str) -> None: temp_directory_path = get_temp_directory_path(target_path) Path(temp_directory_path).mkdir(parents=True, exist_ok=True) @@ -107,21 +158,30 @@ def clean_temp(target_path: str) -> None: os.rmdir(parent_directory_path) +def delete_temp_frames(filename: str) -> None: + dir = os.path.dirname(os.path.dirname(filename)) + shutil.rmtree(dir) + + def has_image_extension(image_path: str) -> bool: - return image_path.lower().endswith(('png', 'jpg', 'jpeg', 'webp')) + return image_path.lower().endswith(("png", "jpg", "jpeg", "webp")) + + +def has_extension(filepath: str, extensions: List[str]) -> bool: + return filepath.lower().endswith(tuple(extensions)) def is_image(image_path: str) -> bool: if image_path and os.path.isfile(image_path): mimetype, _ = mimetypes.guess_type(image_path) - return bool(mimetype and mimetype.startswith('image/')) + return bool(mimetype and mimetype.startswith("image/")) return False def is_video(video_path: str) -> bool: if video_path and os.path.isfile(video_path): mimetype, _ = mimetypes.guess_type(video_path) - return bool(mimetype and mimetype.startswith('video/')) + return bool(mimetype and mimetype.startswith("video/")) return False @@ -129,13 +189,151 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None: if not os.path.exists(download_directory_path): os.makedirs(download_directory_path) for url in urls: - download_file_path = os.path.join(download_directory_path, os.path.basename(url)) + download_file_path = os.path.join( + download_directory_path, os.path.basename(url) + ) if not os.path.exists(download_file_path): - request = urllib.request.urlopen(url) # type: ignore[attr-defined] - total = int(request.headers.get('Content-Length', 0)) - with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: - urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined] + request = urllib.request.urlopen(url) # type: ignore[attr-defined] + total = int(request.headers.get("Content-Length", 0)) + with tqdm( + total=total, + desc=f"Downloading {url}", + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as progress: + urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined] + + +def get_local_files_from_folder(folder: str) -> List[str]: + if not os.path.exists(folder) or not os.path.isdir(folder): + return None + files = [ + os.path.join(folder, f) + for f in os.listdir(folder) + if os.path.isfile(os.path.join(folder, f)) + ] + return files def resolve_relative_path(path: str) -> str: return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) + + +def get_device() -> str: + if len(roop.globals.execution_providers) < 1: + roop.globals.execution_providers = ["CPUExecutionProvider"] + + prov = roop.globals.execution_providers[0] + if "CoreMLExecutionProvider" in prov: + return "mps" + if "CUDAExecutionProvider" in prov or "ROCMExecutionProvider" in prov: + return "cuda" + if "OpenVINOExecutionProvider" in prov: + return "mkl" + return "cpu" + + +def str_to_class(module_name, class_name) -> Any: + from importlib import import_module + + class_ = None + try: + module_ = import_module(module_name) + try: + class_ = getattr(module_, class_name)() + except AttributeError: + print(f"Class {class_name} does not exist") + except ImportError: + print(f"Module {module_name} does not exist") + return class_ + +def is_installed(name:str) -> bool: + return shutil.which(name); + +# Taken from https://stackoverflow.com/a/68842705 +def get_platform() -> str: + if sys.platform == "linux": + try: + proc_version = open("/proc/version").read() + if "Microsoft" in proc_version: + return "wsl" + except: + pass + return sys.platform + +def open_with_default_app(filename:str): + if filename == None: + return + platform = get_platform() + if platform == "darwin": + subprocess.call(("open", filename)) + elif platform in ["win64", "win32"]: os.startfile(filename.replace("/", "\\")) + elif platform == "wsl": + subprocess.call("cmd.exe /C start".split() + [filename]) + else: # linux variants + subprocess.call("xdg-open", filename) + + +def prepare_for_batch(target_files) -> str: + print("Preparing temp files") + tempfolder = os.path.join(tempfile.gettempdir(), "rooptmp") + if os.path.exists(tempfolder): + shutil.rmtree(tempfolder) + Path(tempfolder).mkdir(parents=True, exist_ok=True) + for f in target_files: + newname = os.path.basename(f.name) + shutil.move(f.name, os.path.join(tempfolder, newname)) + return tempfolder + + +def zip(files, zipname): + with zipfile.ZipFile(zipname, "w") as zip_file: + for f in files: + zip_file.write(f, os.path.basename(f)) + + +def unzip(zipfilename: str, target_path: str): + with zipfile.ZipFile(zipfilename, "r") as zip_file: + zip_file.extractall(target_path) + + +def mkdir_with_umask(directory): + oldmask = os.umask(0) + # mode needs octal + os.makedirs(directory, mode=0o775, exist_ok=True) + os.umask(oldmask) + + +def open_folder(path: str): + platform = get_platform() + try: + if platform == "darwin": + subprocess.call(("open", path)) + elif platform in ["win64", "win32"]: + open_with_default_app(path) + elif platform == "wsl": + subprocess.call("cmd.exe /C start".split() + [path]) + else: # linux variants + subprocess.Popen(["xdg-open", path]) + except Exception as e: + traceback.print_exc() + pass + # import webbrowser + # webbrowser.open(url) + + +def create_version_html() -> str: + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + versions_html = f""" +python: {python_version} +โ€ข +torch: {getattr(torch, '__long_version__',torch.__version__)} +โ€ข +gradio: {gradio.__version__} +""" + return versions_html + + +def compute_cosine_distance(emb1, emb2) -> float: + return distance.cosine(emb1, emb2) diff --git a/roop/virtualcam.py b/roop/virtualcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d429851bb610789386a4d11866d2663f43bd78be --- /dev/null +++ b/roop/virtualcam.py @@ -0,0 +1,87 @@ +import cv2 +import roop.globals +import ui.globals +import pyvirtualcam +import threading +import platform + + +cam_active = False +cam_thread = None +vcam = None + +def virtualcamera(streamobs, cam_num,width,height): + from roop.ProcessOptions import ProcessOptions + from roop.core import live_swap, get_processing_plugins + + global cam_active + + #time.sleep(2) + print('Starting capture') + cap = cv2.VideoCapture(cam_num, cv2.CAP_DSHOW if platform.system() != 'Darwin' else cv2.CAP_AVFOUNDATION) + if not cap.isOpened(): + print("Cannot open camera") + cap.release() + del cap + return + + pref_width = width + pref_height = height + pref_fps_in = 30 + cap.set(cv2.CAP_PROP_FRAME_WIDTH, pref_width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, pref_height) + cap.set(cv2.CAP_PROP_FPS, pref_fps_in) + cam_active = True + + # native format UYVY + + cam = None + if streamobs: + print('Detecting virtual cam devices') + cam = pyvirtualcam.Camera(width=pref_width, height=pref_height, fps=pref_fps_in, fmt=pyvirtualcam.PixelFormat.BGR, print_fps=False) + if cam: + print(f'Using virtual camera: {cam.device}') + print(f'Using {cam.native_fmt}') + else: + print(f'Not streaming to virtual camera!') + + # always use xseg masking + options = ProcessOptions(get_processing_plugins("mask_xseg"), roop.globals.distance_threshold, roop.globals.blend_ratio, + "all", 0, None, None, 1, False) + while cam_active: + ret, frame = cap.read() + if not ret: + break + + if len(roop.globals.INPUT_FACESETS) > 0: + frame = live_swap(frame, options) + if cam: + cam.send(frame) + cam.sleep_until_next_frame() + ui.globals.ui_camera_frame = frame + + if cam: + cam.close() + cap.release() + print('Camera stopped') + + + +def start_virtual_cam(streamobs, cam_number, resolution): + global cam_thread, cam_active + + if not cam_active: + width, height = map(int, resolution.split('x')) + cam_thread = threading.Thread(target=virtualcamera, args=[streamobs, cam_number, width, height]) + cam_thread.start() + + + +def stop_virtual_cam(): + global cam_active, cam_thread + + if cam_active: + cam_active = False + cam_thread.join() + + diff --git a/roop/vr_util.py b/roop/vr_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a72845e3c2c3cc89f6567ebfc13bf77d306710ff --- /dev/null +++ b/roop/vr_util.py @@ -0,0 +1,57 @@ +import cv2 +import numpy as np + +# VR Lense Distortion +# Taken from https://github.com/g0kuvonlange/vrswap + + +def get_perspective(img, FOV, THETA, PHI, height, width): + # + # THETA is left/right angle, PHI is up/down angle, both in degree + # + [orig_width, orig_height, _] = img.shape + equ_h = orig_height + equ_w = orig_width + equ_cx = (equ_w - 1) / 2.0 + equ_cy = (equ_h - 1) / 2.0 + + wFOV = FOV + hFOV = float(height) / width * wFOV + + w_len = np.tan(np.radians(wFOV / 2.0)) + h_len = np.tan(np.radians(hFOV / 2.0)) + + x_map = np.ones([height, width], np.float32) + y_map = np.tile(np.linspace(-w_len, w_len, width), [height, 1]) + z_map = -np.tile(np.linspace(-h_len, h_len, height), [width, 1]).T + + D = np.sqrt(x_map**2 + y_map**2 + z_map**2) + xyz = np.stack((x_map, y_map, z_map), axis=2) / np.repeat( + D[:, :, np.newaxis], 3, axis=2 + ) + + y_axis = np.array([0.0, 1.0, 0.0], np.float32) + z_axis = np.array([0.0, 0.0, 1.0], np.float32) + [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA)) + [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(-PHI)) + + xyz = xyz.reshape([height * width, 3]).T + xyz = np.dot(R1, xyz) + xyz = np.dot(R2, xyz).T + lat = np.arcsin(xyz[:, 2]) + lon = np.arctan2(xyz[:, 1], xyz[:, 0]) + + lon = lon.reshape([height, width]) / np.pi * 180 + lat = -lat.reshape([height, width]) / np.pi * 180 + + lon = lon / 180 * equ_cx + equ_cx + lat = lat / 90 * equ_cy + equ_cy + + persp = cv2.remap( + img, + lon.astype(np.float32), + lat.astype(np.float32), + cv2.INTER_CUBIC, + borderMode=cv2.BORDER_WRAP, + ) + return persp diff --git a/run.py b/run.py new file mode 100755 index 0000000000000000000000000000000000000000..b52e5cc4a8ea9ce5cadd4e7111fb15531f380314 --- /dev/null +++ b/run.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from roop import core + +if __name__ == '__main__': + core.run() diff --git a/settings.py b/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..eaed8e0d33375c38c0bf44f0d79c96f0c646c36d --- /dev/null +++ b/settings.py @@ -0,0 +1,68 @@ +import yaml + +class Settings: + def __init__(self, config_file): + self.config_file = config_file + self.load() + + def default_get(_, data, name, default): + value = default + try: + value = data.get(name, default) + except: + pass + return value + + + def load(self): + try: + with open(self.config_file, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + except: + data = None + + self.selected_theme = self.default_get(data, 'selected_theme', "Default") + self.server_name = self.default_get(data, 'server_name', "") + self.server_port = self.default_get(data, 'server_port', 0) + self.server_share = self.default_get(data, 'server_share', False) + self.output_image_format = self.default_get(data, 'output_image_format', 'png') + self.output_video_format = self.default_get(data, 'output_video_format', 'mp4') + self.output_video_codec = self.default_get(data, 'output_video_codec', 'libx264') + self.video_quality = self.default_get(data, 'video_quality', 14) + self.clear_output = self.default_get(data, 'clear_output', True) + self.max_threads = self.default_get(data, 'max_threads', 2) + self.memory_limit = self.default_get(data, 'memory_limit', 0) + self.provider = self.default_get(data, 'provider', 'cuda') + self.force_cpu = self.default_get(data, 'force_cpu', False) + self.output_template = self.default_get(data, 'output_template', '{file}_{time}') + self.use_os_temp_folder = self.default_get(data, 'use_os_temp_folder', False) + self.output_show_video = self.default_get(data, 'output_show_video', True) + + + + + + def save(self): + data = { + 'selected_theme': self.selected_theme, + 'server_name': self.server_name, + 'server_port': self.server_port, + 'server_share': self.server_share, + 'output_image_format' : self.output_image_format, + 'output_video_format' : self.output_video_format, + 'output_video_codec' : self.output_video_codec, + 'video_quality' : self.video_quality, + 'clear_output' : self.clear_output, + 'max_threads' : self.max_threads, + 'memory_limit' : self.memory_limit, + 'provider' : self.provider, + 'force_cpu' : self.force_cpu, + 'output_template' : self.output_template, + 'use_os_temp_folder' : self.use_os_temp_folder, + 'output_show_video' : self.output_show_video + } + with open(self.config_file, 'w') as f: + yaml.dump(data, f) + + + diff --git a/train_log/.DS_Store b/train_log/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..198104583e60199ba4bf4eb0d324c0daadae3363 Binary files /dev/null and b/train_log/.DS_Store differ diff --git a/train_log/IFNet_HDv3.py b/train_log/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e4cf8e196cbcf61527e5d710b8555e712caa49 --- /dev/null +++ b/train_log/IFNet_HDv3.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.warplayer import warp +# from train_log.refine import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True) + ) + +class Head(nn.Module): + def __init__(self): + super(Head, self).__init__() + self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1) + self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\ +) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c//2, 3, 2, 1), + conv(c//2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4*6, 4, 2, 1), + nn.PixelShuffle(2) + ) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7+16, c=192) + self.block1 = IFBlock(8+4+16, c=128) + self.block2 = IFBlock(8+4+16, c=96) + self.block3 = IFBlock(8+4+16, c=64) + self.encode = Head() + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward(self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + loss_cons = 0 + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i]) + if ensemble: + f_, m_ = block[i](torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1-timestep), 1), None, scale=scale_list[i]) + flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (mask + (-m_)) / 2 + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask), 1), flow, scale=scale_list[i]) + if ensemble: + f_, m_ = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], wf1, wf0, 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[3] = (warped_img0 * mask + warped_img1 * (1 - mask)) + if not fastmode: + print('contextnet is removed') + ''' + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[3] = torch.clamp(merged[3] + res, 0, 1) + ''' + return flow_list, mask_list[3], merged diff --git a/train_log/RIFE_HDv3.py b/train_log/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..897c1cc6468919fd08e11b135f446d8fa9ff7a37 --- /dev/null +++ b/train_log/RIFE_HDv3.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from train_log.IFNet_HDv3 import * +import torch.nn.functional as F +from model.loss import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.device() + self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.epe = EPE() + self.version = 4.8 + # self.vgg = VGGPerceptualLoss().to(device) + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank=0): + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } + else: + return param + if rank <= 0: + if torch.cuda.is_available(): + self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))), False) + else: + self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu')), False) + + def save_model(self, path, rank=0): + if rank == 0: + torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) + + def inference(self, img0, img1, timestep=0.5, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [8/scale, 4/scale, 2/scale, 1/scale] + flow, mask, merged = self.flownet(imgs, timestep, scale_list) + return merged[3] + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + scale = [8, 4, 2, 1] + flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + loss_l1 = (merged[3] - gt).abs().mean() + loss_smooth = self.sobel(flow[3], flow[3]*0).mean() + # loss_vgg = self.vgg(merged[2], gt) + if training: + self.optimG.zero_grad() + loss_G = loss_l1 + loss_cons + loss_smooth * 0.1 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[3], { + 'mask': mask, + 'flow': flow[3][:, :2], + 'loss_l1': loss_l1, + 'loss_cons': loss_cons, + 'loss_smooth': loss_smooth, + } diff --git a/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc b/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d69ff6e1283580275df326bdb487fa571113ae64 Binary files /dev/null and b/train_log/__pycache__/IFNet_HDv3.cpython-310.pyc differ diff --git a/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc b/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9324decdc7150c3da99274f3ca695c9ac893aa61 Binary files /dev/null and b/train_log/__pycache__/RIFE_HDv3.cpython-310.pyc differ diff --git a/train_log/flownet.pkl b/train_log/flownet.pkl new file mode 100644 index 0000000000000000000000000000000000000000..aa218a2b4b78f404cdb9d66aebed32113a8f7b6c --- /dev/null +++ b/train_log/flownet.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1ee3186270312a38316e4d53c77b31a60062cfa5636e13d6f0a1dd89bb7b128 +size 21508207 diff --git a/train_log/refine.py b/train_log/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..41b648ec12403f442f8bf0941bed9b0d896f2d87 --- /dev/null +++ b/train_log/refine.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.functional as F + +device = torch.device("cuda") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.LeakyReLU(0.2, True) + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +c = 16 +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2*c) + self.conv3 = Conv2(2*c, 4*c) + self.conv4 = Conv2(4*c, 8*c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.Conv2d(c, 3, 3, 1, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/ui/globals.py b/ui/globals.py new file mode 100644 index 0000000000000000000000000000000000000000..5514a63d6e6e00bfb72938f8648e7eb5575d601a --- /dev/null +++ b/ui/globals.py @@ -0,0 +1,15 @@ +ui_restart_server = False + +SELECTION_FACES_DATA = None +ui_SELECTED_INPUT_FACE_INDEX = 0 + +ui_selected_enhancer = None +ui_blend_ratio = None +ui_input_thumbs = [] +ui_target_thumbs = [] +ui_camera_frame = None + + + + + diff --git a/ui/main.py b/ui/main.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf64b1b63f132119432cc795ca76fcb5b134200 --- /dev/null +++ b/ui/main.py @@ -0,0 +1,88 @@ +import os +import time +import gradio as gr +import roop.globals +import roop.metadata +import roop.utilities as util +import ui.globals as uii + +from ui.tabs.faceswap_tab import faceswap_tab +from ui.tabs.livecam_tab import livecam_tab +from ui.tabs.facemgr_tab import facemgr_tab +from ui.tabs.extras_tab import extras_tab +from ui.tabs.settings_tab import settings_tab + +roop.globals.keep_fps = None +roop.globals.keep_frames = None +roop.globals.skip_audio = None +roop.globals.use_batch = None + + +def prepare_environment(): + roop.globals.output_path = os.path.abspath(os.path.join(os.getcwd(), "output")) + os.makedirs(roop.globals.output_path, exist_ok=True) + if not roop.globals.CFG.use_os_temp_folder: + os.environ["TEMP"] = os.environ["TMP"] = os.path.abspath(os.path.join(os.getcwd(), "temp")) + os.makedirs(os.environ["TEMP"], exist_ok=True) + os.environ["GRADIO_TEMP_DIR"] = os.environ["TEMP"] + + +def run(): + from roop.core import decode_execution_providers, set_display_ui + + prepare_environment() + + set_display_ui(show_msg) + roop.globals.execution_providers = decode_execution_providers([roop.globals.CFG.provider]) + print(f'Using provider {roop.globals.execution_providers} - Device:{util.get_device()}') + + run_server = True + uii.ui_restart_server = False + mycss = """ + span {color: var(--block-info-text-color)} + #fixedheight { + max-height: 238.4px; + overflow-y: auto !important; + } + .image-container.svelte-1l6wqyv {height: 100%} + + """ + + while run_server: + server_name = roop.globals.CFG.server_name + if server_name is None or len(server_name) < 1: + server_name = None + server_port = roop.globals.CFG.server_port + if server_port <= 0: + server_port = None + ssl_verify = False if server_name == '0.0.0.0' else True + with gr.Blocks(title=f'{roop.metadata.name} {roop.metadata.version}', theme=roop.globals.CFG.selected_theme, css=mycss) as ui: + with gr.Row(variant='compact'): + gr.Markdown(f"### [{roop.metadata.name} {roop.metadata.version}](https://github.com/C0untFloyd/roop-unleashed)") + gr.HTML(util.create_version_html(), elem_id="versions") + faceswap_tab() + livecam_tab() + facemgr_tab() + extras_tab() + settings_tab() + + uii.ui_restart_server = False + try: + ui.queue().launch(inbrowser=True, server_name=server_name, server_port=server_port, share=roop.globals.CFG.server_share, ssl_verify=ssl_verify, prevent_thread_lock=True, show_error=True) + except Exception as e: + print(f'Exception {e} when launching Gradio Server!') + uii.ui_restart_server = True + run_server = False + try: + while uii.ui_restart_server == False: + time.sleep(1.0) + + except (KeyboardInterrupt, OSError): + print("Keyboard interruption in main thread... closing server.") + run_server = False + ui.close() + + +def show_msg(msg: str): + gr.Info(msg) + diff --git a/ui/tabs/extras_tab.py b/ui/tabs/extras_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..7686542649857d1d2b61722460066ea71f38d12d --- /dev/null +++ b/ui/tabs/extras_tab.py @@ -0,0 +1,184 @@ +import os +import gradio as gr +import shutil +import roop.utilities as util +import roop.util_ffmpeg as ffmpeg +import roop.globals + +frame_filters_map = { + "Colorize B/W Images (Deoldify Artistic)" : {"colorizer" : {"subtype": "deoldify_artistic"}}, + "Colorize B/W Images (Deoldify Stable)" : {"colorizer" : {"subtype": "deoldify_stable"}}, + "Background remove" : {"removebg" : {"subtype": ""}}, + "Filter Stylize" : {"filter_generic" : {"subtype" : "stylize" }}, + "Filter Detail Enhance" : {"filter_generic" : {"subtype" : "detailenhance" }}, + "Filter Pencil Sketch" : {"filter_generic" : {"subtype" : "pencil" }}, + "Filter Cartoon" : {"filter_generic" : {"subtype" : "cartoon" }}, + "Filter C64" : {"filter_generic" : {"subtype" : "C64" }} + } + +frame_upscalers_map = { + "ESRGAN x2" : {"upscale" : {"subtype": "esrganx2"}}, + "ESRGAN x4" : {"upscale" : {"subtype": "esrganx4"}}, + "LSDIR x4" : {"upscale" : {"subtype": "lsdirx4"}} +} + +def extras_tab(): + filternames = ["None"] + for f in frame_filters_map.keys(): + filternames.append(f) + upscalernames = ["None"] + for f in frame_upscalers_map.keys(): + upscalernames.append(f) + + with gr.Tab("๐ŸŽ‰ Extras"): + with gr.Row(): + files_to_process = gr.Files(label='File(s) to process', file_count="multiple", file_types=["image", "video"]) + with gr.Row(variant='panel'): + with gr.Accordion(label="Video/GIF", open=False): + with gr.Row(variant='panel'): + with gr.Column(): + gr.Markdown(""" + # Poor man's video editor + Re-encoding uses your configuration from the Settings Tab. + """) + with gr.Column(): + cut_start_time = gr.Slider(0, 1000000, value=0, label="Start Frame", step=1.0, interactive=True) + with gr.Column(): + cut_end_time = gr.Slider(1, 1000000, value=1, label="End Frame", step=1.0, interactive=True) + with gr.Column(): + extras_chk_encode = gr.Checkbox(label='Re-encode videos (necessary for videos with different codecs)', value=False) + start_cut_video = gr.Button("Cut video") + start_extract_frames = gr.Button("Extract frames") + start_join_videos = gr.Button("Join videos") + + with gr.Row(variant='panel'): + with gr.Column(): + gr.Markdown(""" + # Create video/gif from images + """) + with gr.Column(): + extras_fps = gr.Slider(minimum=0, maximum=120, value=30, label="Video FPS", step=1.0, interactive=True) + extras_images_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True) + with gr.Column(): + extras_chk_creategif = gr.Checkbox(label='Create GIF from video', value=False) + extras_create_video=gr.Button("Create") + with gr.Row(variant='panel'): + with gr.Accordion(label="Full frame processing", open=True): + with gr.Row(variant='panel'): + filterselection = gr.Dropdown(filternames, value="None", label="Colorizer/FilterFX", interactive=True) + upscalerselection = gr.Dropdown(upscalernames, value="None", label="Enhancer", interactive=True) + with gr.Row(variant='panel'): + start_frame_process=gr.Button("Start processing") + + with gr.Row(): + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Row(): + extra_files_output = gr.Files(label='Resulting output files', file_count="multiple") + + start_cut_video.click(fn=on_cut_video, inputs=[files_to_process, cut_start_time, cut_end_time, extras_chk_encode], outputs=[extra_files_output]) + start_extract_frames.click(fn=on_extras_extract_frames, inputs=[files_to_process], outputs=[extra_files_output]) + start_join_videos.click(fn=on_join_videos, inputs=[files_to_process, extras_chk_encode], outputs=[extra_files_output]) + extras_create_video.click(fn=on_extras_create_video, inputs=[extras_images_folder, extras_fps, extras_chk_creategif], outputs=[extra_files_output]) + start_frame_process.click(fn=on_frame_process, inputs=[files_to_process, filterselection, upscalerselection], outputs=[extra_files_output]) + + +def on_cut_video(files, cut_start_frame, cut_end_frame, reencode): + if files is None: + return None + + resultfiles = [] + for tf in files: + f = tf.name + destfile = util.get_destfilename_from_path(f, roop.globals.output_path, '_cut') + ffmpeg.cut_video(f, destfile, cut_start_frame, cut_end_frame, reencode) + if os.path.isfile(destfile): + resultfiles.append(destfile) + else: + gr.Error('Cutting video failed!') + return resultfiles + + +def on_join_videos(files, chk_encode): + if files is None: + return None + + filenames = [] + for f in files: + filenames.append(f.name) + destfile = util.get_destfilename_from_path(filenames[0], roop.globals.output_path, '_join') + sorted_filenames = util.sort_filenames_ignore_path(filenames) + ffmpeg.join_videos(sorted_filenames, destfile, not chk_encode) + resultfiles = [] + if os.path.isfile(destfile): + resultfiles.append(destfile) + else: + gr.Error('Joining videos failed!') + return resultfiles + + + +def on_extras_create_video(images_path,fps, create_gif): + util.sort_rename_frames(os.path.dirname(images_path)) + destfilename = os.path.join(roop.globals.output_path, "img2video." + roop.globals.CFG.output_video_format) + ffmpeg.create_video('', destfilename, fps, images_path) + resultfiles = [] + if os.path.isfile(destfilename): + resultfiles.append(destfilename) + else: + return None + if create_gif: + gifname = util.get_destfilename_from_path(destfilename, './output', '.gif') + ffmpeg.create_gif_from_video(destfilename, gifname) + if os.path.isfile(destfilename): + resultfiles.append(gifname) + return resultfiles + + +def on_extras_extract_frames(files): + if files is None: + return None + + resultfiles = [] + for tf in files: + f = tf.name + resfolder = ffmpeg.extract_frames(f) + for file in os.listdir(resfolder): + outfile = os.path.join(resfolder, file) + if os.path.isfile(outfile): + resultfiles.append(outfile) + return resultfiles + + +def on_frame_process(files, filterselection, upscaleselection): + import pathlib + from roop.core import batch_process_with_options + from roop.ProcessEntry import ProcessEntry + from roop.ProcessOptions import ProcessOptions + from ui.main import prepare_environment + + + if files is None: + return None + + if roop.globals.CFG.clear_output: + shutil.rmtree(roop.globals.output_path) + prepare_environment() + list_files_process : list[ProcessEntry] = [] + + for tf in files: + list_files_process.append(ProcessEntry(tf.name, 0,0, 0)) + + processoroptions = {} + filter = next((x for x in frame_filters_map.keys() if x == filterselection), None) + if filter is not None: + processoroptions.update(frame_filters_map[filter]) + filter = next((x for x in frame_upscalers_map.keys() if x == upscaleselection), None) + if filter is not None: + processoroptions.update(frame_upscalers_map[filter]) + options = ProcessOptions(processoroptions, 0, 0, "all", 0, None, None, None, False) + batch_process_with_options(list_files_process, options, None) + outdir = pathlib.Path(roop.globals.output_path) + outfiles = [str(item) for item in outdir.rglob("*") if item.is_file()] + return outfiles + + diff --git a/ui/tabs/facemgr_tab.py b/ui/tabs/facemgr_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..05a5ac3f53c02703a8b14216b1fd051b57f87977 --- /dev/null +++ b/ui/tabs/facemgr_tab.py @@ -0,0 +1,187 @@ +import os +import shutil +import cv2 +import gradio as gr +import roop.utilities as util +import roop.globals +from roop.face_util import extract_face_images +from roop.capturer import get_video_frame, get_video_frame_total +from typing import List, Tuple, Optional +from roop.typing import Frame, Face, FaceSet + +selected_face_index = -1 +thumbs = [] +images = [] + + +def facemgr_tab() -> None: + with gr.Tab("๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ Face Management"): + with gr.Row(): + gr.Markdown(""" + # Create blending facesets + Add multiple reference images into a faceset file. + """) + with gr.Row(): + videoimagefst = gr.Image(label="Cut face from video frame", height=576, interactive=False, visible=True) + with gr.Row(): + frame_num_fst = gr.Slider(1, 1, value=1, label="Frame Number", info='0:00:00', step=1.0, interactive=False) + fb_cutfromframe = gr.Button("Use faces from this frame", variant='secondary', interactive=False) + with gr.Row(): + fb_facesetfile = gr.Files(label='Faceset', file_count='single', file_types=['.fsz'], interactive=True) + fb_files = gr.Files(label='Input Files', file_count="multiple", file_types=["image", "video"], interactive=True) + with gr.Row(): + with gr.Column(): + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Column(): + gr.Markdown(' ') + with gr.Row(): + faces = gr.Gallery(label="Faces in this Faceset", allow_preview=True, preview=True, height=128, object_fit="scale-down") + with gr.Row(): + fb_remove = gr.Button("Remove selected", variant='secondary') + fb_update = gr.Button("Create/Update Faceset file", variant='primary') + fb_clear = gr.Button("Clear all", variant='stop') + + fb_facesetfile.change(fn=on_faceset_changed, inputs=[fb_facesetfile], outputs=[faces]) + fb_files.change(fn=on_fb_files_changed, inputs=[fb_files], outputs=[faces, videoimagefst, frame_num_fst, fb_cutfromframe]) + fb_update.click(fn=on_update_clicked, outputs=[fb_facesetfile]) + fb_remove.click(fn=on_remove_clicked, outputs=[faces]) + fb_clear.click(fn=on_clear_clicked, outputs=[faces, fb_files, fb_facesetfile]) + fb_cutfromframe.click(fn=on_cutfromframe_clicked, inputs=[fb_files, frame_num_fst], outputs=[faces]) + frame_num_fst.release(fn=on_frame_num_fst_changed, inputs=[fb_files, frame_num_fst], outputs=[videoimagefst]) + faces.select(fn=on_face_selected) + + +def on_faceset_changed(faceset, progress=gr.Progress()) -> List[Frame]: + global thumbs, images + + if faceset is None: + return thumbs + + thumbs.clear() + filename = faceset.name + + if filename.lower().endswith('fsz'): + progress(0, desc="Retrieving faces from Faceset File", ) + unzipfolder = os.path.join(os.environ["TEMP"], 'faceset') + if os.path.isdir(unzipfolder): + shutil.rmtree(unzipfolder) + util.mkdir_with_umask(unzipfolder) + util.unzip(filename, unzipfolder) + for file in os.listdir(unzipfolder): + if file.endswith(".png"): + SELECTION_FACES_DATA = extract_face_images(os.path.join(unzipfolder,file), (False, 0), 0.5) + if len(SELECTION_FACES_DATA) < 1: + gr.Warning(f"No face detected in {file}!") + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + + return thumbs + + +def on_fb_files_changed(inputfiles, progress=gr.Progress()) -> Tuple[List[Frame], Optional[gr.Image], Optional[gr.Slider], Optional[gr.Button]]: + global thumbs, images, total_frames, current_video_fps + + if inputfiles is None or len(inputfiles) < 1: + return thumbs, None, None, None + + progress(0, desc="Retrieving faces from images", ) + slider = None + video_image = None + cut_button = None + for f in inputfiles: + source_path = f.name + if util.has_image_extension(source_path): + slider = gr.Slider(interactive=False) + video_image = gr.Image(interactive=False) + cut_button = gr.Button(interactive=False) + roop.globals.source_path = source_path + SELECTION_FACES_DATA = extract_face_images(roop.globals.source_path, (False, 0), 0.5) + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + elif util.is_video(source_path) or source_path.lower().endswith('gif'): + total_frames = get_video_frame_total(source_path) + current_video_fps = util.detect_fps(source_path) + cut_button = gr.Button(interactive=True) + video_image, slider = display_video_frame(source_path, 1, total_frames) + + return thumbs, video_image, slider, cut_button + + +def display_video_frame(filename: str, frame_num: int, total: int=0) -> Tuple[gr.Image, gr.Slider]: + global current_video_fps + + current_frame = get_video_frame(filename, frame_num) + if current_video_fps == 0: + current_video_fps = 1 + secs = (frame_num - 1) / current_video_fps + minutes = secs / 60 + secs = secs % 60 + hours = minutes / 60 + minutes = minutes % 60 + milliseconds = (secs - int(secs)) * 1000 + timeinfo = f"{int(hours):0>2}:{int(minutes):0>2}:{int(secs):0>2}.{int(milliseconds):0>3}" + if total > 0: + return gr.Image(value=util.convert_to_gradio(current_frame), interactive=True), gr.Slider(info=timeinfo, minimum=1, maximum=total, interactive=True) + return gr.Image(value=util.convert_to_gradio(current_frame), interactive=True), gr.Slider(info=timeinfo, interactive=True) + + +def on_face_selected(evt: gr.SelectData) -> None: + global selected_face_index + + if evt is not None: + selected_face_index = evt.index + +def on_frame_num_fst_changed(inputfiles: List[gr.Files], frame_num: int) -> Frame: + filename = inputfiles[0].name + video_image, _ = display_video_frame(filename, frame_num, 0) + return video_image + + +def on_cutfromframe_clicked(inputfiles: List[gr.Files], frame_num: int) -> List[Frame]: + global thumbs + + filename = inputfiles[0].name + SELECTION_FACES_DATA = extract_face_images(filename, (True, frame_num), 0.5) + for f in SELECTION_FACES_DATA: + image = f[1] + images.append(image) + thumbs.append(util.convert_to_gradio(image)) + return thumbs + + +def on_remove_clicked() -> List[Frame]: + global thumbs, images, selected_face_index + + if len(thumbs) > selected_face_index: + f = thumbs.pop(selected_face_index) + del f + f = images.pop(selected_face_index) + del f + return thumbs + +def on_clear_clicked() -> Tuple[List[Frame], None, None]: + global thumbs, images + + thumbs.clear() + images.clear() + return thumbs, None, None + + +def on_update_clicked() -> Optional[str]: + if len(images) < 1: + gr.Warning(f"No faces to create faceset from!") + return None + + imgnames = [] + for index,img in enumerate(images): + filename = os.path.join(roop.globals.output_path, f'{index}.png') + cv2.imwrite(filename, img) + imgnames.append(filename) + + finalzip = os.path.join(roop.globals.output_path, 'faceset.fsz') + util.zip(imgnames, finalzip) + return finalzip diff --git a/ui/tabs/faceswap_tab.py b/ui/tabs/faceswap_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..97ec1fab45ce21296014c9cc896c072494a582c2 --- /dev/null +++ b/ui/tabs/faceswap_tab.py @@ -0,0 +1,717 @@ +import os +import shutil +import pathlib +import gradio as gr +import roop.utilities as util +import roop.globals +import ui.globals +from roop.face_util import extract_face_images, create_blank_image +from roop.capturer import get_video_frame, get_video_frame_total, get_image_frame +from roop.ProcessEntry import ProcessEntry +from roop.ProcessOptions import ProcessOptions +from roop.FaceSet import FaceSet + +last_image = None + + +IS_INPUT = True +SELECTED_FACE_INDEX = 0 + +SELECTED_INPUT_FACE_INDEX = 0 +SELECTED_TARGET_FACE_INDEX = 0 + +input_faces = None +target_faces = None +face_selection = None +previewimage = None + +selected_preview_index = 0 + +is_processing = False + +list_files_process : list[ProcessEntry] = [] +no_face_choices = ["Use untouched original frame","Retry rotated", "Skip Frame", "Skip Frame if no similar face"] + +current_video_fps = 50 + +manual_masking = False + + +def faceswap_tab(): + global no_face_choices, previewimage + + with gr.Tab("๐ŸŽญ Face Swap"): + with gr.Row(variant='panel'): + with gr.Column(scale=2): + with gr.Row(): + with gr.Column(min_width=160): + input_faces = gr.Gallery(label="Input faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8) + with gr.Accordion(label="Advanced Masking", open=False): + chk_showmaskoffsets = gr.Checkbox(label="Show mask overlay in preview", value=False, interactive=True) + mask_top = gr.Slider(0, 1.0, value=0, label="Offset Face Top", step=0.01, interactive=True) + mask_bottom = gr.Slider(0, 1.0, value=0, label="Offset Face Bottom", step=0.01, interactive=True) + mask_left = gr.Slider(0, 1.0, value=0, label="Offset Face Left", step=0.01, interactive=True) + mask_right = gr.Slider(0, 1.0, value=0, label="Offset Face Right", step=0.01, interactive=True) + mask_erosion = gr.Slider(1.0, 3.0, value=1.0, label="Erosion Iterations", step=1.00, interactive=True) + mask_blur = gr.Slider(10.0, 50.0, value=20.0, label="Blur size", step=1.00, interactive=True) + bt_toggle_masking = gr.Button("Toggle manual masking", variant='secondary', size='sm') + selected_mask_engine = gr.Dropdown(["None", "Clip2Seg", "DFL XSeg"], value="None", label="Face masking engine") + clip_text = gr.Textbox(label="List of objects to mask and restore back on fake face", value="cup,hands,hair,banana", interactive=False) + bt_preview_mask = gr.Button("๐Ÿ‘ฅ Show Mask Preview", variant='secondary') + bt_remove_selected_input_face = gr.Button("โŒ Remove selected", size='sm') + bt_clear_input_faces = gr.Button("๐Ÿ’ฅ Clear all", variant='stop', size='sm') + with gr.Column(min_width=160): + target_faces = gr.Gallery(label="Target faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8) + bt_remove_selected_target_face = gr.Button("โŒ Remove selected", size='sm') + bt_add_local = gr.Button('Add local files from', size='sm') + local_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True) + with gr.Row(variant='panel'): + bt_srcfiles = gr.Files(label='Source File(s)', file_count="multiple", file_types=["image", ".fsz"], elem_id='filelist', height=233) + bt_destfiles = gr.Files(label='Target File(s)', file_count="multiple", file_types=["image", "video"], elem_id='filelist', height=233) + with gr.Row(variant='panel'): + gr.Markdown('') + forced_fps = gr.Slider(minimum=0, maximum=120, value=0, label="Video FPS", info='Overrides detected fps if not 0', step=1.0, interactive=True, container=True) + + with gr.Column(scale=2): + previewimage = gr.Image(label="Preview Image", height=576, interactive=False, visible=True) + maskimage = gr.ImageEditor(label="Manual mask Image", sources=["clipboard"], transforms="", type="numpy", + brush=gr.Brush(color_mode="fixed", colors=["rgba(255, 255, 255, 1"]), interactive=True, visible=False) + with gr.Row(variant='panel'): + fake_preview = gr.Checkbox(label="Face swap frames", value=False) + bt_refresh_preview = gr.Button("๐Ÿ”„ Refresh", variant='secondary', size='sm') + bt_use_face_from_preview = gr.Button("Use Face from this Frame", variant='primary', size='sm') + with gr.Row(): + preview_frame_num = gr.Slider(1, 1, value=1, label="Frame Number", info='0:00:00', step=1.0, interactive=True) + with gr.Row(): + text_frame_clip = gr.Markdown('Processing frame range [0 - 0]') + set_frame_start = gr.Button("โฌ… Set as Start", size='sm') + set_frame_end = gr.Button("โžก Set as End", size='sm') + with gr.Row(visible=False) as dynamic_face_selection: + with gr.Column(scale=2): + face_selection = gr.Gallery(label="Detected faces", allow_preview=False, preview=False, height=256, object_fit="cover", columns=8) + with gr.Column(): + bt_faceselect = gr.Button("โ˜‘ Use selected face", size='sm') + bt_cancelfaceselect = gr.Button("Done", size='sm') + with gr.Column(): + gr.Markdown(' ') + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + selected_face_detection = gr.Dropdown(["First found", "All female", "All male", "All faces", "Selected face"], value="First found", label="Specify face selection for swapping") + with gr.Column(scale=1): + ui.globals.ui_selected_enhancer = gr.Dropdown(["None", "Codeformer", "DMDNet", "GFPGAN", "GPEN", "Restoreformer++"], value="None", label="Select post-processing") + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + max_face_distance = gr.Slider(0.01, 1.0, value=0.65, label="Max Face Similarity Threshold", info="0.0 = identical 1.0 = no similarity") + with gr.Column(scale=1): + num_swap_steps = gr.Slider(1, 5, value=1, step=1.0, label="Number of swapping steps", info="More steps can increase likeness") + with gr.Column(scale=2): + ui.globals.ui_blend_ratio = gr.Slider(0.0, 1.0, value=0.65, label="Original/Enhanced image blend ratio", info="Only used with active post-processing") + + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + video_swapping_method = gr.Dropdown(["Extract Frames to media","In-Memory processing"], value="In-Memory processing", label="Select video processing method", interactive=True) + no_face_action = gr.Dropdown(choices=no_face_choices, value=no_face_choices[0], label="Action on no face detected", interactive=True) + vr_mode = gr.Checkbox(label="VR Mode", value=False) + with gr.Column(scale=1): + with gr.Group(): + autorotate = gr.Checkbox(label="Auto rotate horizontal Faces", value=True) + roop.globals.skip_audio = gr.Checkbox(label="Skip audio", value=False) + roop.globals.keep_frames = gr.Checkbox(label="Keep Frames (relevant only when extracting frames)", value=False) + roop.globals.wait_after_extraction = gr.Checkbox(label="Wait for user key press before creating video ", value=False) + + + + with gr.Row(variant='panel'): + with gr.Column(): + bt_start = gr.Button("โ–ถ Start", variant='primary') + gr.Button("๐Ÿ‘€ Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path)) + with gr.Column(): + bt_stop = gr.Button("โน Stop", variant='secondary', interactive=False) + with gr.Column(scale=2): + gr.Markdown(' ') + with gr.Row(variant='panel'): + with gr.Column(): + resultfiles = gr.Files(label='Processed File(s)', interactive=False) + with gr.Column(): + resultimage = gr.Image(type='filepath', label='Final Image', interactive=False ) + resultvideo = gr.Video(label='Final Video', interactive=False, visible=False) + + previewinputs = [preview_frame_num, bt_destfiles, fake_preview, ui.globals.ui_selected_enhancer, selected_face_detection, + max_face_distance, ui.globals.ui_blend_ratio, selected_mask_engine, clip_text, no_face_action, vr_mode, autorotate, maskimage, chk_showmaskoffsets, num_swap_steps] + previewoutputs = [previewimage, maskimage, preview_frame_num] + input_faces.select(on_select_input_face, None, None).then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + bt_remove_selected_input_face.click(fn=remove_selected_input_face, outputs=[input_faces]) + bt_srcfiles.change(fn=on_srcfile_changed, show_progress='full', inputs=bt_srcfiles, outputs=[dynamic_face_selection, face_selection, input_faces]) + + mask_top.release(fn=on_mask_top_changed, inputs=[mask_top], show_progress='hidden') + mask_bottom.release(fn=on_mask_bottom_changed, inputs=[mask_bottom], show_progress='hidden') + mask_left.release(fn=on_mask_left_changed, inputs=[mask_left], show_progress='hidden') + mask_right.release(fn=on_mask_right_changed, inputs=[mask_right], show_progress='hidden') + mask_erosion.release(fn=on_mask_erosion_changed, inputs=[mask_erosion], show_progress='hidden') + mask_blur.release(fn=on_mask_blur_changed, inputs=[mask_blur], show_progress='hidden') + selected_mask_engine.change(fn=on_mask_engine_changed, inputs=[selected_mask_engine], outputs=[clip_text], show_progress='hidden') + + + target_faces.select(on_select_target_face, None, None) + bt_remove_selected_target_face.click(fn=remove_selected_target_face, outputs=[target_faces]) + + forced_fps.change(fn=on_fps_changed, inputs=[forced_fps], show_progress='hidden') + bt_destfiles.change(fn=on_destfiles_changed, inputs=[bt_destfiles], outputs=[preview_frame_num, text_frame_clip], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden') + bt_destfiles.select(fn=on_destfiles_selected, outputs=[preview_frame_num, text_frame_clip, forced_fps], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden') + bt_destfiles.clear(fn=on_clear_destfiles, outputs=[target_faces, selected_face_detection]) + resultfiles.select(fn=on_resultfiles_selected, inputs=[resultfiles], outputs=[resultimage, resultvideo]) + + face_selection.select(on_select_face, None, None) + bt_faceselect.click(fn=on_selected_face, outputs=[input_faces, target_faces, selected_face_detection]) + bt_cancelfaceselect.click(fn=on_end_face_selection, outputs=[dynamic_face_selection, face_selection]) + + bt_clear_input_faces.click(fn=on_clear_input_faces, outputs=[input_faces]) + + + bt_add_local.click(fn=on_add_local_folder, inputs=[local_folder], outputs=[bt_destfiles]) + bt_preview_mask.click(fn=on_preview_mask, inputs=[preview_frame_num, bt_destfiles, clip_text, selected_mask_engine], outputs=[previewimage]) + + start_event = bt_start.click(fn=start_swap, + inputs=[ui.globals.ui_selected_enhancer, selected_face_detection, roop.globals.keep_frames, roop.globals.wait_after_extraction, + roop.globals.skip_audio, max_face_distance, ui.globals.ui_blend_ratio, selected_mask_engine, clip_text,video_swapping_method, no_face_action, vr_mode, autorotate, num_swap_steps, maskimage], + outputs=[bt_start, bt_stop, resultfiles], show_progress='full') + after_swap_event = start_event.then(fn=on_resultfiles_finished, inputs=[resultfiles], outputs=[resultimage, resultvideo]) + + bt_stop.click(fn=stop_swap, cancels=[start_event, after_swap_event], outputs=[bt_start, bt_stop], queue=False) + + bt_refresh_preview.click(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + bt_toggle_masking.click(fn=on_toggle_masking, inputs=[previewimage, maskimage], outputs=[previewimage, maskimage]) + fake_preview.change(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs) + preview_frame_num.release(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden', ) + bt_use_face_from_preview.click(fn=on_use_face_from_selected, show_progress='full', inputs=[bt_destfiles, preview_frame_num], outputs=[dynamic_face_selection, face_selection, target_faces, selected_face_detection]) + set_frame_start.click(fn=on_set_frame, inputs=[set_frame_start, preview_frame_num], outputs=[text_frame_clip]) + set_frame_end.click(fn=on_set_frame, inputs=[set_frame_end, preview_frame_num], outputs=[text_frame_clip]) + + + +def on_mask_top_changed(mask_offset): + set_mask_offset(0, mask_offset) + +def on_mask_bottom_changed(mask_offset): + set_mask_offset(1, mask_offset) + +def on_mask_left_changed(mask_offset): + set_mask_offset(2, mask_offset) + +def on_mask_right_changed(mask_offset): + set_mask_offset(3, mask_offset) + +def on_mask_erosion_changed(mask_offset): + set_mask_offset(4, mask_offset) +def on_mask_blur_changed(mask_offset): + set_mask_offset(5, mask_offset) + + +def set_mask_offset(index, mask_offset): + global SELECTED_INPUT_FACE_INDEX + + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + offs = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets + offs[index] = mask_offset + if offs[0] + offs[1] > 0.99: + offs[0] = 0.99 + offs[1] = 0.0 + if offs[2] + offs[3] > 0.99: + offs[2] = 0.99 + offs[3] = 0.0 + roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets = offs + +def on_mask_engine_changed(mask_engine): + if mask_engine == "Clip2Seg": + return gr.Textbox(interactive=True) + return gr.Textbox(interactive=False) + + + +def on_add_local_folder(folder): + files = util.get_local_files_from_folder(folder) + if files is None: + gr.Warning("Empty folder or folder not found!") + return files + + +def on_srcfile_changed(srcfiles, progress=gr.Progress()): + global SELECTION_FACES_DATA, IS_INPUT, input_faces, face_selection, last_image + + IS_INPUT = True + + if srcfiles is None or len(srcfiles) < 1: + return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + + thumbs = [] + for f in srcfiles: + source_path = f.name + if source_path.lower().endswith('fsz'): + progress(0, desc="Retrieving faces from Faceset File") + unzipfolder = os.path.join(os.environ["TEMP"], 'faceset') + if os.path.isdir(unzipfolder): + files = os.listdir(unzipfolder) + for file in files: + os.remove(os.path.join(unzipfolder, file)) + else: + os.makedirs(unzipfolder) + util.mkdir_with_umask(unzipfolder) + util.unzip(source_path, unzipfolder) + is_first = True + face_set = FaceSet() + for file in os.listdir(unzipfolder): + if file.endswith(".png"): + filename = os.path.join(unzipfolder,file) + progress(0, desc="Extracting faceset") + SELECTION_FACES_DATA = extract_face_images(filename, (False, 0)) + for f in SELECTION_FACES_DATA: + face = f[0] + face.mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(face) + if is_first: + image = util.convert_to_gradio(f[1]) + ui.globals.ui_input_thumbs.append(image) + is_first = False + face_set.ref_images.append(get_image_frame(filename)) + if len(face_set.faces) > 0: + if len(face_set.faces) > 1: + face_set.AverageEmbeddings() + roop.globals.INPUT_FACESETS.append(face_set) + + elif util.has_image_extension(source_path): + progress(0, desc="Retrieving faces from image") + roop.globals.source_path = source_path + SELECTION_FACES_DATA = extract_face_images(roop.globals.source_path, (False, 0)) + progress(0.5, desc="Retrieving faces from image") + for f in SELECTION_FACES_DATA: + face_set = FaceSet() + face = f[0] + face.mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(face) + image = util.convert_to_gradio(f[1]) + ui.globals.ui_input_thumbs.append(image) + roop.globals.INPUT_FACESETS.append(face_set) + + progress(1.0) + + # old style with selecting input faces commented out + # if len(thumbs) < 1: + # return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + # return gr.Column(visible=True), thumbs, gr.Gallery(visible=True) + + return gr.Column(visible=False), None, ui.globals.ui_input_thumbs + + +def on_select_input_face(evt: gr.SelectData): + global SELECTED_INPUT_FACE_INDEX + + SELECTED_INPUT_FACE_INDEX = evt.index + + +def remove_selected_input_face(): + global SELECTED_INPUT_FACE_INDEX + + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + f = roop.globals.INPUT_FACESETS.pop(SELECTED_INPUT_FACE_INDEX) + del f + if len(ui.globals.ui_input_thumbs) > SELECTED_INPUT_FACE_INDEX: + f = ui.globals.ui_input_thumbs.pop(SELECTED_INPUT_FACE_INDEX) + del f + + return ui.globals.ui_input_thumbs + +def on_select_target_face(evt: gr.SelectData): + global SELECTED_TARGET_FACE_INDEX + + SELECTED_TARGET_FACE_INDEX = evt.index + +def remove_selected_target_face(): + if len(roop.globals.TARGET_FACES) > SELECTED_TARGET_FACE_INDEX: + f = roop.globals.TARGET_FACES.pop(SELECTED_TARGET_FACE_INDEX) + del f + if len(ui.globals.ui_target_thumbs) > SELECTED_TARGET_FACE_INDEX: + f = ui.globals.ui_target_thumbs.pop(SELECTED_TARGET_FACE_INDEX) + del f + return ui.globals.ui_target_thumbs + + + + + +def on_use_face_from_selected(files, frame_num): + global IS_INPUT, SELECTION_FACES_DATA + + IS_INPUT = False + thumbs = [] + + roop.globals.target_path = files[selected_preview_index].name + if util.is_image(roop.globals.target_path) and not roop.globals.target_path.lower().endswith(('gif')): + SELECTION_FACES_DATA = extract_face_images(roop.globals.target_path, (False, 0)) + if len(SELECTION_FACES_DATA) > 0: + for f in SELECTION_FACES_DATA: + image = util.convert_to_gradio(f[1]) + thumbs.append(image) + else: + gr.Info('No faces detected!') + roop.globals.target_path = None + + elif util.is_video(roop.globals.target_path) or roop.globals.target_path.lower().endswith(('gif')): + selected_frame = frame_num + SELECTION_FACES_DATA = extract_face_images(roop.globals.target_path, (True, selected_frame)) + if len(SELECTION_FACES_DATA) > 0: + for f in SELECTION_FACES_DATA: + image = util.convert_to_gradio(f[1]) + thumbs.append(image) + else: + gr.Info('No faces detected!') + roop.globals.target_path = None + + if len(thumbs) == 1: + roop.globals.TARGET_FACES.append(SELECTION_FACES_DATA[0][0]) + ui.globals.ui_target_thumbs.append(thumbs[0]) + return gr.Row(visible=False), None, ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face') + + return gr.Row(visible=True), thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True) + + + +def on_select_face(evt: gr.SelectData): # SelectData is a subclass of EventData + global SELECTED_FACE_INDEX + SELECTED_FACE_INDEX = evt.index + + +def on_selected_face(): + global IS_INPUT, SELECTED_FACE_INDEX, SELECTION_FACES_DATA + + fd = SELECTION_FACES_DATA[SELECTED_FACE_INDEX] + image = util.convert_to_gradio(fd[1]) + if IS_INPUT: + face_set = FaceSet() + fd[0].mask_offsets = (0,0,0,0,1,20) + face_set.faces.append(fd[0]) + roop.globals.INPUT_FACESETS.append(face_set) + ui.globals.ui_input_thumbs.append(image) + return ui.globals.ui_input_thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True) + else: + roop.globals.TARGET_FACES.append(fd[0]) + ui.globals.ui_target_thumbs.append(image) + return gr.Gallery(visible=True), ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face') + +# bt_faceselect.click(fn=on_selected_face, outputs=[dynamic_face_selection, face_selection, input_faces, target_faces]) + +def on_end_face_selection(): + return gr.Column(visible=False), None + + +def on_preview_frame_changed(frame_num, files, fake_preview, enhancer, detection, face_distance, blend_ratio, + selected_mask_engine, clip_text, no_face_action, vr_mode, auto_rotate, maskimage, show_face_area, num_steps): + global SELECTED_INPUT_FACE_INDEX, manual_masking, current_video_fps + + from roop.core import live_swap, get_processing_plugins + + manual_masking = False + mask_offsets = (0,0,0,0) + if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX: + if not hasattr(roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0], 'mask_offsets'): + roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets = mask_offsets + mask_offsets = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets + + timeinfo = '0:00:00' + if files is None or selected_preview_index >= len(files) or frame_num is None: + return None,None, gr.Slider(info=timeinfo) + + filename = files[selected_preview_index].name + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename, frame_num) + if current_video_fps == 0: + current_video_fps = 1 + secs = (frame_num - 1) / current_video_fps + minutes = secs / 60 + secs = secs % 60 + hours = minutes / 60 + minutes = minutes % 60 + milliseconds = (secs - int(secs)) * 1000 + timeinfo = f"{int(hours):0>2}:{int(minutes):0>2}:{int(secs):0>2}.{int(milliseconds):0>3}" + else: + current_frame = get_image_frame(filename) + if current_frame is None: + return None, None, gr.Slider(info=timeinfo) + + layers = None + if maskimage is not None: + layers = maskimage["layers"] + + if not fake_preview or len(roop.globals.INPUT_FACESETS) < 1: + return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo) + + roop.globals.face_swap_mode = translate_swap_mode(detection) + roop.globals.selected_enhancer = enhancer + roop.globals.distance_threshold = face_distance + roop.globals.blend_ratio = blend_ratio + roop.globals.no_face_action = index_of_no_face_action(no_face_action) + roop.globals.vr_mode = vr_mode + roop.globals.autorotate_faces = auto_rotate + + mask_engine = map_mask_engine(selected_mask_engine, clip_text) + + roop.globals.execution_threads = roop.globals.CFG.max_threads + mask = layers[0] if layers is not None else None + face_index = SELECTED_INPUT_FACE_INDEX + if len(roop.globals.INPUT_FACESETS) <= face_index: + face_index = 0 + + options = ProcessOptions(get_processing_plugins(mask_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, + roop.globals.face_swap_mode, face_index, clip_text, maskimage, num_steps, show_face_area) + + current_frame = live_swap(current_frame, options) + if current_frame is None: + return gr.Image(visible=True), None, gr.Slider(info=timeinfo) + return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo) + +def map_mask_engine(selected_mask_engine, clip_text): + if selected_mask_engine == "Clip2Seg": + mask_engine = "mask_clip2seg" + if clip_text is None or len(clip_text) < 1: + mask_engine = None + elif selected_mask_engine == "DFL XSeg": + mask_engine = "mask_xseg" + else: + mask_engine = None + return mask_engine + + + +def on_toggle_masking(previewimage, mask): + global manual_masking + + manual_masking = not manual_masking + if manual_masking: + layers = mask["layers"] + if len(layers) == 1: + layers = [create_blank_image(previewimage.shape[1],previewimage.shape[0])] + return gr.Image(visible=False), gr.ImageEditor(value={"background": previewimage, "layers": layers, "composite": None}, visible=True) + return gr.Image(visible=True), gr.ImageEditor(visible=False) + +def gen_processing_text(start, end): + return f'Processing frame range [{start} - {end}]' + +def on_set_frame(sender:str, frame_num): + global selected_preview_index, list_files_process + + idx = selected_preview_index + if list_files_process[idx].endframe == 0: + return gen_processing_text(0,0) + + start = list_files_process[idx].startframe + end = list_files_process[idx].endframe + if sender.lower().endswith('start'): + list_files_process[idx].startframe = min(frame_num, end) + else: + list_files_process[idx].endframe = max(frame_num, start) + + return gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe) + + + +def on_preview_mask(frame_num, files, clip_text, mask_engine): + from roop.core import live_swap, get_processing_plugins + global is_processing + + if is_processing or files is None or selected_preview_index >= len(files) or clip_text is None or frame_num is None: + return None + + filename = files[selected_preview_index].name + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename, frame_num + ) + else: + current_frame = get_image_frame(filename) + if current_frame is None or mask_engine is None: + return None + if mask_engine == "Clip2Seg": + mask_engine = "mask_clip2seg" + if clip_text is None or len(clip_text) < 1: + mask_engine = None + elif mask_engine == "DFL XSeg": + mask_engine = "mask_xseg" + options = ProcessOptions(get_processing_plugins(mask_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, + "all", 0, clip_text, None, 0, False, True) + + current_frame = live_swap(current_frame, options) + return util.convert_to_gradio(current_frame) + + + +def on_clear_input_faces(): + ui.globals.ui_input_thumbs.clear() + roop.globals.INPUT_FACESETS.clear() + return ui.globals.ui_input_thumbs + +def on_clear_destfiles(): + roop.globals.TARGET_FACES.clear() + ui.globals.ui_target_thumbs.clear() + return ui.globals.ui_target_thumbs, gr.Dropdown(value="First found") + + +def index_of_no_face_action(dropdown_text): + global no_face_choices + + return no_face_choices.index(dropdown_text) + +def translate_swap_mode(dropdown_text): + if dropdown_text == "Selected face": + return "selected" + elif dropdown_text == "First found": + return "first" + elif dropdown_text == "All female": + return "all_female" + elif dropdown_text == "All male": + return "all_male" + + return "all" + + + +def start_swap( enhancer, detection, keep_frames, wait_after_extraction, skip_audio, face_distance, blend_ratio, + selected_mask_engine, clip_text, processing_method, no_face_action, vr_mode, autorotate, num_swap_steps, imagemask, progress=gr.Progress()): + from ui.main import prepare_environment + from roop.core import batch_process_regular + global is_processing, list_files_process + + if list_files_process is None or len(list_files_process) <= 0: + return gr.Button(variant="primary"), None, None + + if roop.globals.CFG.clear_output: + shutil.rmtree(roop.globals.output_path) + + if not util.is_installed("ffmpeg"): + msg = "ffmpeg is not installed! No video processing possible." + gr.Warning(msg) + + prepare_environment() + + roop.globals.selected_enhancer = enhancer + roop.globals.target_path = None + roop.globals.distance_threshold = face_distance + roop.globals.blend_ratio = blend_ratio + roop.globals.keep_frames = keep_frames + roop.globals.wait_after_extraction = wait_after_extraction + roop.globals.skip_audio = skip_audio + roop.globals.face_swap_mode = translate_swap_mode(detection) + roop.globals.no_face_action = index_of_no_face_action(no_face_action) + roop.globals.vr_mode = vr_mode + roop.globals.autorotate_faces = autorotate + mask_engine = map_mask_engine(selected_mask_engine, clip_text) + + if roop.globals.face_swap_mode == 'selected': + if len(roop.globals.TARGET_FACES) < 1: + gr.Error('No Target Face selected!') + return gr.Button(variant="primary"), None, None + + is_processing = True + yield gr.Button(variant="secondary", interactive=False), gr.Button(variant="primary", interactive=True), None + roop.globals.execution_threads = roop.globals.CFG.max_threads + roop.globals.video_encoder = roop.globals.CFG.output_video_codec + roop.globals.video_quality = roop.globals.CFG.video_quality + roop.globals.max_memory = roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None + + batch_process_regular(list_files_process, mask_engine, clip_text, processing_method == "In-Memory processing", imagemask, num_swap_steps, progress, SELECTED_INPUT_FACE_INDEX) + is_processing = False + outdir = pathlib.Path(roop.globals.output_path) + outfiles = [str(item) for item in outdir.rglob("*") if item.is_file()] + if len(outfiles) > 0: + yield gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),gr.Files(value=outfiles) + else: + yield gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),None + + +def stop_swap(): + roop.globals.processing = False + gr.Info('Aborting processing - please wait for the remaining threads to be stopped') + return gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),None + + +def on_fps_changed(fps): + global selected_preview_index, list_files_process + + if len(list_files_process) < 1 or list_files_process[selected_preview_index].endframe < 1: + return + list_files_process[selected_preview_index].fps = fps + + +def on_destfiles_changed(destfiles): + global selected_preview_index, list_files_process, current_video_fps + + if destfiles is None or len(destfiles) < 1: + list_files_process.clear() + return gr.Slider(value=1, maximum=1, info='0:00:00'), '' + + for f in destfiles: + list_files_process.append(ProcessEntry(f.name, 0,0, 0)) + + selected_preview_index = 0 + idx = selected_preview_index + + filename = list_files_process[idx].filename + + if util.is_video(filename) or filename.lower().endswith('gif'): + total_frames = get_video_frame_total(filename) + current_video_fps = util.detect_fps(filename) + else: + total_frames = 1 + list_files_process[idx].endframe = total_frames + if total_frames > 1: + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe) + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), '' + + + + +def on_destfiles_selected(evt: gr.SelectData): + global selected_preview_index, list_files_process, current_video_fps + + if evt is not None: + selected_preview_index = evt.index + idx = selected_preview_index + filename = list_files_process[idx].filename + fps = list_files_process[idx].fps + if util.is_video(filename) or filename.lower().endswith('gif'): + total_frames = get_video_frame_total(filename) + current_video_fps = util.detect_fps(filename) + if list_files_process[idx].endframe == 0: + list_files_process[idx].endframe = total_frames + else: + total_frames = 1 + + if total_frames > 1: + return gr.Slider(value=list_files_process[idx].startframe, maximum=total_frames, info='0:00:00'), gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe), fps + return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(0,0), fps + + + +def on_resultfiles_selected(evt: gr.SelectData, files): + selected_index = evt.index + filename = files[selected_index].name + return display_output(filename) + +def on_resultfiles_finished(files): + selected_index = 0 + if files is None or len(files) < 1: + return None, None + + filename = files[selected_index].name + return display_output(filename) + + +def display_output(filename): + if util.is_video(filename) and roop.globals.CFG.output_show_video: + return gr.Image(visible=False), gr.Video(visible=True, value=filename) + else: + if util.is_video(filename) or filename.lower().endswith('gif'): + current_frame = get_video_frame(filename) + else: + current_frame = get_image_frame(filename) + return gr.Image(visible=True, value=util.convert_to_gradio(current_frame)), gr.Video(visible=False) diff --git a/ui/tabs/livecam_tab.py b/ui/tabs/livecam_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b5a228f8a324291be072790d22828350109b12 --- /dev/null +++ b/ui/tabs/livecam_tab.py @@ -0,0 +1,54 @@ +import gradio as gr +import roop.globals +import ui.globals + + +camera_frame = None + +def livecam_tab(): + with gr.Tab("๐ŸŽฅ Live Cam"): + with gr.Row(variant='panel'): + gr.Markdown(""" + This feature will allow you to use your physical webcam and apply the selected faces to the stream. + You can also forward the stream to a virtual camera, which can be used in video calls or streaming software.
+ Supported are: v4l2loopback (linux), OBS Virtual Camera (macOS/Windows) and unitycapture (Windows).
+ **Please note:** to change the face or any other settings you need to stop and restart a running live cam. + """) + + with gr.Row(variant='panel'): + with gr.Column(): + bt_start = gr.Button("โ–ถ Start", variant='primary') + with gr.Column(): + bt_stop = gr.Button("โน Stop", variant='secondary', interactive=False) + with gr.Column(): + camera_num = gr.Slider(0, 8, value=0, label="Camera Number", step=1.0, interactive=True) + cb_obs = gr.Checkbox(label="Forward stream to virtual camera", interactive=True) + with gr.Column(): + dd_reso = gr.Dropdown(choices=["640x480","1280x720", "1920x1080"], value="1280x720", label="Fake Camera Resolution", interactive=True) + + with gr.Row(): + fake_cam_image = gr.Image(label='Fake Camera Output', interactive=False) + + start_event = bt_start.click(fn=start_cam, inputs=[cb_obs, camera_num, dd_reso, ui.globals.ui_selected_enhancer, ui.globals.ui_blend_ratio],outputs=[bt_start, bt_stop,fake_cam_image]) + bt_stop.click(fn=stop_swap, cancels=[start_event], outputs=[bt_start, bt_stop], queue=False) + + +def start_cam(stream_to_obs, cam, reso, enhancer, blend_ratio): + from roop.virtualcam import start_virtual_cam + from roop.utilities import convert_to_gradio + + start_virtual_cam(stream_to_obs, cam, reso) + roop.globals.selected_enhancer = enhancer + roop.globals.blend_ratio = blend_ratio + while True: + yield gr.Button(interactive=False), gr.Button(interactive=True), convert_to_gradio(ui.globals.ui_camera_frame) + + +def stop_swap(): + from roop.virtualcam import stop_virtual_cam + stop_virtual_cam() + return gr.Button(interactive=True), gr.Button(interactive=False) + + + + diff --git a/ui/tabs/settings_tab.py b/ui/tabs/settings_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b34e91ac3090946afa7ec1bd2dd6ef29767f54 --- /dev/null +++ b/ui/tabs/settings_tab.py @@ -0,0 +1,129 @@ +import shutil +import os +import gradio as gr +import roop.globals +import ui.globals + +available_themes = ["Default", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", "gstaff/xkcd", "freddyaboulton/dracula_revamped", "ysharma/steampunk"] +image_formats = ['jpg','png', 'webp'] +video_formats = ['avi','mkv', 'mp4', 'webm'] +video_codecs = ['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'] +providerlist = None + +settings_controls = [] + +def settings_tab(): + from roop.core import suggest_execution_providers + global providerlist + + providerlist = suggest_execution_providers() + with gr.Tab("โš™ Settings"): + with gr.Row(): + with gr.Column(): + themes = gr.Dropdown(available_themes, label="Theme", info="Change needs complete restart", value=roop.globals.CFG.selected_theme) + with gr.Column(): + settings_controls.append(gr.Checkbox(label="Public Server", value=roop.globals.CFG.server_share, elem_id='server_share', interactive=True)) + settings_controls.append(gr.Checkbox(label='Clear output folder before each run', value=roop.globals.CFG.clear_output, elem_id='clear_output', interactive=True)) + output_template = gr.Textbox(label="Filename Output Template", info="(file extension is added automatically)", lines=1, placeholder='{file}_{time}', value=roop.globals.CFG.output_template) + with gr.Column(): + input_server_name = gr.Textbox(label="Server Name", lines=1, info="Leave blank to run locally", value=roop.globals.CFG.server_name) + with gr.Column(): + input_server_port = gr.Number(label="Server Port", precision=0, info="Leave at 0 to use default", value=roop.globals.CFG.server_port) + with gr.Row(): + with gr.Column(): + settings_controls.append(gr.Dropdown(providerlist, label="Provider", value=roop.globals.CFG.provider, elem_id='provider', interactive=True)) + chk_det_size = gr.Checkbox(label="Use default Det-Size", value=True, elem_id='default_det_size', interactive=True) + settings_controls.append(gr.Checkbox(label="Force CPU for Face Analyser", value=roop.globals.CFG.force_cpu, elem_id='force_cpu', interactive=True)) + max_threads = gr.Slider(1, 32, value=roop.globals.CFG.max_threads, label="Max. Number of Threads", info='default: 3', step=1.0, interactive=True) + with gr.Column(): + memory_limit = gr.Slider(0, 128, value=roop.globals.CFG.memory_limit, label="Max. Memory to use (Gb)", info='0 meaning no limit', step=1.0, interactive=True) + settings_controls.append(gr.Dropdown(image_formats, label="Image Output Format", info='default: png', value=roop.globals.CFG.output_image_format, elem_id='output_image_format', interactive=True)) + with gr.Column(): + settings_controls.append(gr.Dropdown(video_codecs, label="Video Codec", info='default: libx264', value=roop.globals.CFG.output_video_codec, elem_id='output_video_codec', interactive=True)) + settings_controls.append(gr.Dropdown(video_formats, label="Video Output Format", info='default: mp4', value=roop.globals.CFG.output_video_format, elem_id='output_video_format', interactive=True)) + video_quality = gr.Slider(0, 100, value=roop.globals.CFG.video_quality, label="Video Quality (crf)", info='default: 14', step=1.0, interactive=True) + with gr.Column(): + with gr.Group(): + settings_controls.append(gr.Checkbox(label='Use OS temp folder', value=roop.globals.CFG.use_os_temp_folder, elem_id='use_os_temp_folder', interactive=True)) + settings_controls.append(gr.Checkbox(label='Show video in browser (re-encodes output)', value=roop.globals.CFG.output_show_video, elem_id='output_show_video', interactive=True)) + button_apply_restart = gr.Button("Restart Server", variant='primary') + button_clean_temp = gr.Button("Clean temp folder") + button_apply_settings = gr.Button("Apply Settings") + + chk_det_size.select(fn=on_option_changed) + + # Settings + for s in settings_controls: + s.select(fn=on_settings_changed) + max_threads.input(fn=lambda a,b='max_threads':on_settings_changed_misc(a,b), inputs=[max_threads]) + memory_limit.input(fn=lambda a,b='memory_limit':on_settings_changed_misc(a,b), inputs=[memory_limit]) + video_quality.input(fn=lambda a,b='video_quality':on_settings_changed_misc(a,b), inputs=[video_quality]) + + # button_clean_temp.click(fn=clean_temp, outputs=[bt_srcfiles, input_faces, target_faces, bt_destfiles]) + button_clean_temp.click(fn=clean_temp) + button_apply_settings.click(apply_settings, inputs=[themes, input_server_name, input_server_port, output_template]) + button_apply_restart.click(restart) + + +def on_option_changed(evt: gr.SelectData): + attribname = evt.target.elem_id + if isinstance(evt.target, gr.Checkbox): + if hasattr(roop.globals, attribname): + setattr(roop.globals, attribname, evt.selected) + return + elif isinstance(evt.target, gr.Dropdown): + if hasattr(roop.globals, attribname): + setattr(roop.globals, attribname, evt.value) + return + raise gr.Error(f'Unhandled Setting for {evt.target}') + + +def on_settings_changed_misc(new_val, attribname): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, new_val) + else: + print("Didn't find attrib!") + + + +def on_settings_changed(evt: gr.SelectData): + attribname = evt.target.elem_id + if isinstance(evt.target, gr.Checkbox): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, evt.selected) + return + elif isinstance(evt.target, gr.Dropdown): + if hasattr(roop.globals.CFG, attribname): + setattr(roop.globals.CFG, attribname, evt.value) + return + + raise gr.Error(f'Unhandled Setting for {evt.target}') + +def clean_temp(): + from ui.main import prepare_environment + + if not roop.globals.CFG.use_os_temp_folder: + shutil.rmtree(os.environ["TEMP"]) + prepare_environment() + + ui.globals.ui_input_thumbs.clear() + roop.globals.INPUT_FACESETS.clear() + roop.globals.TARGET_FACES.clear() + ui.globals.ui_target_thumbs = [] + gr.Info('Temp Files removed') + return None,None,None,None + + +def apply_settings(themes, input_server_name, input_server_port, output_template): + from ui.main import show_msg + + roop.globals.CFG.selected_theme = themes + roop.globals.CFG.server_name = input_server_name + roop.globals.CFG.server_port = input_server_port + roop.globals.CFG.output_template = output_template + roop.globals.CFG.save() + show_msg('Settings saved') + + +def restart(): + ui.globals.ui_restart_server = True