Hemanth-thunder committed on
Commit
c0551d3
·
1 Parent(s): 31a0704

End of training

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +15 -0
  2. autotrain-advanced/.dockerignore +9 -0
  3. autotrain-advanced/.github/workflows/build_documentation.yml +19 -0
  4. autotrain-advanced/.github/workflows/build_pr_documentation.yml +17 -0
  5. autotrain-advanced/.github/workflows/code_quality.yml +30 -0
  6. autotrain-advanced/.github/workflows/delete_doc_comment.yml +13 -0
  7. autotrain-advanced/.github/workflows/delete_doc_comment_trigger.yml +12 -0
  8. autotrain-advanced/.github/workflows/tests.yml +30 -0
  9. autotrain-advanced/.github/workflows/upload_pr_documentation.yml +16 -0
  10. autotrain-advanced/.gitignore +138 -0
  11. autotrain-advanced/Dockerfile +65 -0
  12. autotrain-advanced/LICENSE +202 -0
  13. autotrain-advanced/Makefile +28 -0
  14. autotrain-advanced/README.md +13 -0
  15. autotrain-advanced/docs/source/_toctree.yml +28 -0
  16. autotrain-advanced/docs/source/cost.mdx +17 -0
  17. autotrain-advanced/docs/source/dreambooth.mdx +18 -0
  18. autotrain-advanced/docs/source/getting_started.mdx +29 -0
  19. autotrain-advanced/docs/source/image_classification.mdx +40 -0
  20. autotrain-advanced/docs/source/index.mdx +34 -0
  21. autotrain-advanced/docs/source/llm_finetuning.mdx +43 -0
  22. autotrain-advanced/docs/source/model_choice.mdx +24 -0
  23. autotrain-advanced/docs/source/param_choice.mdx +25 -0
  24. autotrain-advanced/docs/source/support.mdx +12 -0
  25. autotrain-advanced/docs/source/text_classification.mdx +60 -0
  26. autotrain-advanced/examples/text_classification_binary.py +77 -0
  27. autotrain-advanced/examples/text_classification_multiclass.py +77 -0
  28. autotrain-advanced/requirements.txt +31 -0
  29. autotrain-advanced/setup.cfg +24 -0
  30. autotrain-advanced/setup.py +71 -0
  31. autotrain-advanced/src/autotrain/__init__.py +24 -0
  32. autotrain-advanced/src/autotrain/app.py +965 -0
  33. autotrain-advanced/src/autotrain/cli/__init__.py +13 -0
  34. autotrain-advanced/src/autotrain/cli/accelerated_autotrain.py +0 -0
  35. autotrain-advanced/src/autotrain/cli/autotrain.py +40 -0
  36. autotrain-advanced/src/autotrain/cli/run_app.py +55 -0
  37. autotrain-advanced/src/autotrain/cli/run_dreambooth.py +469 -0
  38. autotrain-advanced/src/autotrain/cli/run_llm.py +489 -0
  39. autotrain-advanced/src/autotrain/cli/run_setup.py +61 -0
  40. autotrain-advanced/src/autotrain/config.py +12 -0
  41. autotrain-advanced/src/autotrain/dataset.py +344 -0
  42. autotrain-advanced/src/autotrain/dreambooth_app.py +485 -0
  43. autotrain-advanced/src/autotrain/help.py +28 -0
  44. autotrain-advanced/src/autotrain/infer/__init__.py +0 -0
  45. autotrain-advanced/src/autotrain/infer/text_generation.py +50 -0
  46. autotrain-advanced/src/autotrain/languages.py +19 -0
  47. autotrain-advanced/src/autotrain/params.py +512 -0
  48. autotrain-advanced/src/autotrain/preprocessor/__init__.py +0 -0
  49. autotrain-advanced/src/autotrain/preprocessor/dreambooth.py +62 -0
  50. autotrain-advanced/src/autotrain/preprocessor/tabular.py +99 -0
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ base_model: stabilityai/stable-diffusion-xl-base-1.0
4
+ instance_prompt: a photo of sks dog
5
+ tags:
6
+ - text-to-image
7
+ - diffusers
8
+ - autotrain
9
+ inference: true
10
+ ---
11
+
12
+ # DreamBooth trained by AutoTrain
13
+
14
+ Text encoder was not trained.
15
+
autotrain-advanced/.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ build/
2
+ dist/
3
+ logs/
4
+ output/
5
+ output2/
6
+ test/
7
+ test.py
8
+ .DS_Store
9
+ .vscode/
autotrain-advanced/.github/workflows/build_documentation.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Build documentation
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - doc-builder*
8
+ - v*-release
9
+
10
+ jobs:
11
+ build:
12
+ uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13
+ with:
14
+ commit_sha: ${{ github.sha }}
15
+ package: autotrain-advanced
16
+ package_name: autotrain
17
+ secrets:
18
+ token: ${{ secrets.HUGGINGFACE_PUSH }}
19
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
autotrain-advanced/.github/workflows/build_pr_documentation.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Build PR Documentation
2
+
3
+ on:
4
+ pull_request:
5
+
6
+ concurrency:
7
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8
+ cancel-in-progress: true
9
+
10
+ jobs:
11
+ build:
12
+ uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13
+ with:
14
+ commit_sha: ${{ github.event.pull_request.head.sha }}
15
+ pr_number: ${{ github.event.number }}
16
+ package: autotrain-advanced
17
+ package_name: autotrain
autotrain-advanced/.github/workflows/code_quality.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Code quality
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ release:
11
+ types:
12
+ - created
13
+
14
+ jobs:
15
+ check_code_quality:
16
+ name: Check code quality
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+ - name: Set up Python 3.9
21
+ uses: actions/setup-python@v2
22
+ with:
23
+ python-version: 3.9
24
+ - name: Install dependencies
25
+ run: |
26
+ python -m pip install --upgrade pip
27
+ python -m pip install flake8 black isort
28
+ - name: Make quality
29
+ run: |
30
+ make quality
autotrain-advanced/.github/workflows/delete_doc_comment.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Delete doc comment
2
+
3
+ on:
4
+ workflow_run:
5
+ workflows: ["Delete doc comment trigger"]
6
+ types:
7
+ - completed
8
+
9
+ jobs:
10
+ delete:
11
+ uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
12
+ secrets:
13
+ comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
autotrain-advanced/.github/workflows/delete_doc_comment_trigger.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Delete doc comment trigger
2
+
3
+ on:
4
+ pull_request:
5
+ types: [ closed ]
6
+
7
+
8
+ jobs:
9
+ delete:
10
+ uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
11
+ with:
12
+ pr_number: ${{ github.event.number }}
autotrain-advanced/.github/workflows/tests.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ release:
11
+ types:
12
+ - created
13
+
14
+ jobs:
15
+ tests:
16
+ name: Run unit tests
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+ - name: Set up Python 3.9
21
+ uses: actions/setup-python@v2
22
+ with:
23
+ python-version: 3.9
24
+ - name: Install dependencies
25
+ run: |
26
+ python -m pip install --upgrade pip
27
+ python -m pip install .[dev]
28
+ - name: Make test
29
+ run: |
30
+ make test
autotrain-advanced/.github/workflows/upload_pr_documentation.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Upload PR Documentation
2
+
3
+ on:
4
+ workflow_run:
5
+ workflows: ["Build PR Documentation"]
6
+ types:
7
+ - completed
8
+
9
+ jobs:
10
+ build:
11
+ uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12
+ with:
13
+ package_name: autotrain
14
+ secrets:
15
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16
+ comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
autotrain-advanced/.gitignore ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Local stuff
2
+ .DS_Store
3
+ .vscode/
4
+ test/
5
+ test.py
6
+ output/
7
+ output2/
8
+ logs/
9
+
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ pip-wheel-metadata/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104
+ __pypackages__/
105
+
106
+ # Celery stuff
107
+ celerybeat-schedule
108
+ celerybeat.pid
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+
122
+ # Spyder project settings
123
+ .spyderproject
124
+ .spyproject
125
+
126
+ # Rope project settings
127
+ .ropeproject
128
+
129
+ # mkdocs documentation
130
+ /site
131
+
132
+ # mypy
133
+ .mypy_cache/
134
+ .dmypy.json
135
+ dmypy.json
136
+
137
+ # Pyre type checker
138
+ .pyre/
autotrain-advanced/Dockerfile ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive \
4
+ TZ=UTC
5
+
6
+ ENV PATH="${HOME}/miniconda3/bin:${PATH}"
7
+ ARG PATH="${HOME}/miniconda3/bin:${PATH}"
8
+
9
+ RUN mkdir -p /tmp/model
10
+ RUN chown -R 1000:1000 /tmp/model
11
+ RUN mkdir -p /tmp/data
12
+ RUN chown -R 1000:1000 /tmp/data
13
+
14
+ RUN apt-get update && \
15
+ apt-get upgrade -y && \
16
+ apt-get install -y \
17
+ build-essential \
18
+ cmake \
19
+ curl \
20
+ ca-certificates \
21
+ gcc \
22
+ git \
23
+ locales \
24
+ net-tools \
25
+ wget \
26
+ libpq-dev \
27
+ libsndfile1-dev \
28
+ git \
29
+ git-lfs \
30
+ libgl1 \
31
+ && rm -rf /var/lib/apt/lists/*
32
+
33
+
34
+ RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
35
+ git lfs install
36
+
37
+ WORKDIR /app
38
+ RUN mkdir -p /app/.cache
39
+ ENV HF_HOME="/app/.cache"
40
+ RUN chown -R 1000:1000 /app
41
+ USER 1000
42
+ ENV HOME=/app
43
+
44
+ ENV PYTHONPATH=$HOME/app \
45
+ PYTHONUNBUFFERED=1 \
46
+ GRADIO_ALLOW_FLAGGING=never \
47
+ GRADIO_NUM_PORTS=1 \
48
+ GRADIO_SERVER_NAME=0.0.0.0 \
49
+ SYSTEM=spaces
50
+
51
+
52
+ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
53
+ && sh Miniconda3-latest-Linux-x86_64.sh -b -p /app/miniconda \
54
+ && rm -f Miniconda3-latest-Linux-x86_64.sh
55
+ ENV PATH /app/miniconda/bin:$PATH
56
+
57
+ RUN conda create -p /app/env -y python=3.9
58
+
59
+ SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"]
60
+
61
+ RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
62
+ RUN pip install git+https://github.com/huggingface/peft.git
63
+ COPY --chown=1000:1000 . /app/
64
+
65
+ RUN pip install -e .
autotrain-advanced/LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
autotrain-advanced/Makefile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: quality style test
2
+
3
+ # Check that source code meets quality standards
4
+
5
+ quality:
6
+ black --check --line-length 119 --target-version py38 .
7
+ isort --check-only .
8
+ flake8 --max-line-length 119
9
+
10
+ # Format source code automatically
11
+
12
+ style:
13
+ black --line-length 119 --target-version py38 .
14
+ isort .
15
+
16
+ test:
17
+ pytest -sv ./src/
18
+
19
+ docker:
20
+ docker build -t autotrain-advanced:latest .
21
+ docker tag autotrain-advanced:latest huggingface/autotrain-advanced:latest
22
+ docker push huggingface/autotrain-advanced:latest
23
+
24
+ pip:
25
+ rm -rf build/
26
+ rm -rf dist/
27
+ python setup.py sdist bdist_wheel
28
+ twine upload dist/* --verbose
autotrain-advanced/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🤗 AutoTrain Advanced
2
+
3
+ AutoTrain Advanced: faster and easier training and deployments of state-of-the-art machine learning models
4
+
5
+ ## Installation
6
+
7
+ You can Install AutoTrain-Advanced python package via PIP. Please note you will need python >= 3.8 for AutoTrain Advanced to work properly.
8
+
9
+ pip install autotrain-advanced
10
+
11
+ Please make sure that you have git lfs installed. Check out the instructions here: https://github.com/git-lfs/git-lfs/wiki/Installation
12
+
13
+ ## Coming Soon!
autotrain-advanced/docs/source/_toctree.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - sections:
2
+ - local: index
3
+ title: 🤗 AutoTrain
4
+ - local: getting_started
5
+ title: Installation
6
+ - local: cost
7
+ title: How much does it cost?
8
+ - local: support
9
+ title: Get help and support
10
+ title: Get started
11
+ - sections:
12
+ - local: model_choice
13
+ title: Model Selection
14
+ - local: param_choice
15
+ title: Parameter Selection
16
+ title: Selecting Models and Parameters
17
+ - sections:
18
+ - local: text_classification
19
+ title: Text Classification
20
+ - local: llm_finetuning
21
+ title: LLM Finetuning
22
+ title: Text Tasks
23
+ - sections:
24
+ - local: image_classification
25
+ title: Image Classification
26
+ - local: dreambooth
27
+ title: DreamBooth
28
+ title: Image Tasks
autotrain-advanced/docs/source/cost.mdx ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How much does it cost?
2
+
3
+ AutoTrain provides you with best models which are deployable with just a few clicks.
4
+ Unlike other services, we don't own your models. Once the training is done, you can download them and use them anywhere you want.
5
+
6
+ Before you start training, you can see the estimated cost of training.
7
+
8
+ Free tier is available for everyone. For a limited number of samples, you can train your models for free!
9
+ If your dataset is larger, you will be presented with the estimated cost of training.
10
+ Training will begin only after you confirm the payment.
11
+
12
+ Please note that in order to use non-free tier AutoTrain, you need to have a valid payment method on file.
13
+ You can add your payment method in the [billing](https://huggingface.co/settings/billing) section.
14
+
15
+ Estimated cost will be displayed in the UI as follows:
16
+
17
+ ![cost](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/cost.png)
autotrain-advanced/docs/source/dreambooth.mdx ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DreamBooth
2
+
3
+ DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject. It allows the model to generate contextualized images of the subject in different scenes, poses, and views.
4
+
5
+ ![DreamBooth Teaser](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/dreambooth1.jpeg)
6
+
7
+ ## Data Preparation
8
+
9
+ The data format for DreamBooth training is simple. All you need is images of a concept (e.g. a person) and a concept token.
10
+
11
+ ![DreamBooth Training](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/dreambooth2.png)
12
+
13
+ To train a dreambooth model, please select an appropriate model from the hub. You can also let AutoTrain decide the best model for you!
14
+ When choosing a model from the hub, please make sure you select the correct image size compatible with the model.
15
+
16
+ Same as other tasks, you also have an option to select the parameters manually or automatically using AutoTrain.
17
+
18
+ For each concept that you want to train, you must have a concept token and concept images. Concept token is nothing but a word that is not available in the dictionary.
autotrain-advanced/docs/source/getting_started.mdx ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation
2
+
3
+ There is no installation required! AutoTrain Advanced runs on Hugging Face Spaces. All you need to do is create a new space with the AutoTrain Advanced template: https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced. Please make sure you keep the space private.
4
+
5
+ ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_1.png)
6
+
7
+ Once you have selected the Docker > AutoTrain template, you can click on "Create Space" and you will be redirected to your new space.
8
+
9
+ ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_2.png)
10
+
11
+ Once the space is build, you will see this screen:
12
+
13
+ ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_3.png)
14
+
15
+ You can find your token at https://huggingface.co/settings/token.
16
+
17
+ Note: you have to add HF_TOKEN as an environment variable in your space settings. To do so, click on the "Settings" button in the top right corner of your space, then click on "New Secret" in the "Repository Secrets" section and add a new variable with the name HF_TOKEN and your token as the value as shown below:
18
+
19
+ ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_4.png)
20
+
21
+ # Updating AutoTrain Advanced to Latest Version
22
+
23
+ We are constantly adding new features and tasks to AutoTrain Advanced. It's always a good idea to update your space to the latest version before starting a new project. An up-to-date version of AutoTrain Advanced will have the latest tasks, features and bug fixes! Updating is as easy as clicking on the "Factory reboot" button in the settings page of your space.
24
+
25
+ ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_5.png)
26
+
27
+ Please note that "restarting" a space will not update it to the latest version. You need to "Factory reboot" the space to update it to the latest version.
28
+
29
+ And now we are all set and we can start with our first project!
autotrain-advanced/docs/source/image_classification.mdx ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Classification
2
+
3
+ Image classification is a supervised learning problem: define a set of target classes (objects to identify in images), and train a model to recognize them using labeled example photos.
4
+ Using AutoTrain, it's super-easy to train a state-of-the-art image classification model. Just upload a set of images, and AutoTrain will automatically train a model to classify them.
5
+
6
+ ## Data Preparation
7
+
8
+ The data for image classification must be in zip format, with each class in a separate subfolder. For example, if you want to classify cats and dogs, your zip file should look like this:
9
+
10
+ ```
11
+ cats_and_dogs.zip
12
+ ├── cats
13
+ │ ├── cat.1.jpg
14
+ │ ├── cat.2.jpg
15
+ │ ├── cat.3.jpg
16
+ │ └── ...
17
+ └── dogs
18
+ ├── dog.1.jpg
19
+ ├── dog.2.jpg
20
+ ├── dog.3.jpg
21
+ └── ...
22
+ ```
23
+
24
+ Some points to keep in mind:
25
+
26
+ - The zip file should contain multiple folders (the classes), each folder should contain images of a single class.
27
+ - The name of the folder should be the name of the class.
28
+ - The images must be jpeg, jpg or png.
29
+ - There should be at least 5 images per class.
30
+ - There should not be any other files in the zip file.
31
+ - There should not be any other folders inside the zip folder.
32
+
33
+ When train.zip is decompressed, it creates two folders: cats and dogs. These are the two categories for classification. The images for both categories are in their respective folders. You can have as many categories as you want.
34
+
35
+ ## Training
36
+
37
+ Once you have your data ready, you can upload it to AutoTrain and select model and parameters.
38
+ If the estimate looks good, click on `Create Project` button to start training.
39
+
40
+ ![Image Classification](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/image_classification_1.png)
autotrain-advanced/docs/source/index.mdx ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AutoTrain
2
+
3
+ 🤗 AutoTrain is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use.
4
+
5
+ ## Who should use AutoTrain?
6
+
7
+ AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or Tabular task, but doesn't want to spend time on the technical details of training a model. AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model. Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users.
8
+
9
+ ## How to use AutoTrain?
10
+
11
+ We offer several ways to use AutoTrain:
12
+
13
+ - No code users with large number of data samples can use `AutoTrain Advanced` by creating a new space with AutoTrain Docker image: https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced. Please make sure you keep the space private.
14
+
15
+ - No code users with small number of data samples can use AutoTrain using the UI located at: https://ui.autotrain.huggingface.co/projects. Please note that this UI won't be updated with new tasks and features as frequently as AutoTrain Advanced.
16
+
17
+ - Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally. The python api is available in the `autotrain-advanced` package. You can install it using pip:
18
+
19
+ ```bash
20
+ pip install autotrain-advanced
21
+ ```
22
+
23
+ - Developers can also use the AutoTrain API directly. The API is available at: https://api.autotrain.huggingface.co/docs
24
+
25
+
26
+ ## What is AutoTrain Advanced?
27
+
28
+ AutoTrain Advanced processes your data either in a Hugging Face Space or locally (if installed locally using pip). This saves time since the data processing is not done by the AutoTrain backend, resulting in your job not being queued. AutoTrain Advanced also allows you to use your own hardware (better CPU and RAM) to process the data, thus making the data processing faster.
29
+
30
+ Using AutoTrain Advanced, advanced users can also control the hyperparameters used for training per job. This allows you to train multiple models with different hyperparameters and compare the results.
31
+
32
+ Everything else is the same as AutoTrain. You can use AutoTrain Advanced to train models for NLP, CV, Speech and Tabular tasks.
33
+
34
+ We recommend using [AutoTrain Advanced](https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced) since it is faster, more flexible and will have more supported tasks and features in the future.
autotrain-advanced/docs/source/llm_finetuning.mdx ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLM Finetuning
2
+
3
+ With AutoTrain, you can easily finetune large language models (LLMs) on your own data!
4
+
5
+ AutoTrain supports the following types of LLM finetuning:
6
+
7
+ - Causal Language Modeling (CLM)
8
+ - Masked Language Modeling (MLM) [Coming Soon]
9
+
10
+ For LLM finetuning, only Hugging Face Hub model choice is available.
11
+ User needs to select a model from Hugging Face Hub, that they want to finetune and select the parameters on their own (Manual Parameter Selection),
12
+ or use AutoTrain's Auto Parameter Selection to automatically select the best parameters for the task.
13
+
14
+ ## Data Preparation
15
+
16
+ LLM finetuning accepts data in CSV format.
17
+ There are two modes for LLM finetuning: `generic` and `chat`.
18
+ An example dataset with both formats in the same dataset can be found here: https://huggingface.co/datasets/tatsu-lab/alpaca
19
+
20
+ ### Generic
21
+
22
+ In generic mode, only one column is required: `text`.
23
+ The user can take care of how the data is formatted for the task.
24
+ A sample instance for this format is presented below:
25
+
26
+ ```
27
+ Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
28
+
29
+ ### Instruction: Evaluate this sentence for spelling and grammar mistakes
30
+
31
+ ### Input: He finnished his meal and left the resturant
32
+
33
+ ### Response: He finished his meal and left the restaurant.
34
+ ```
35
+
36
+ ![Generic LLM Finetuning](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/llm_1.png)
37
+
38
+ Please note that above is the format for instruction finetuning. But in the `generic` mode, you can also finetune on any other format as you want. The data can be changed according to the requirements.
39
+
40
+
41
+ ## Training
42
+
43
+ Once you have your data ready and estimate verified, you can start training your model by clicking the "Create Project" button.
autotrain-advanced/docs/source/model_choice.mdx ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Choice
2
+
3
+ AutoTrain can automagically select the best models for your task! However, you are also
4
+ allowed to choose the models you want to use. You can choose the most appropriate models
5
+ from the Hugging Face Hub.
6
+
7
+ ![autotrain-model-choice](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/model_choice_1.png)
8
+
9
+ ## AutoTrain Model Choice
10
+
11
+ To let AutoTrain choose the best models for your task, you can use the "AutoTrain"
12
+ in the "Model Choice" section. Once you choose AutoTrain mode, you no longer need to worry about model and parameter selection.
13
+ AutoTrain will automatically select the best models (and parameters) for your task.
14
+
15
+ ## Manual Model Choice
16
+
17
+ To choose the models manually, you can use the "HuggingFace Hub" in the "Model Choice" section.
18
+ For example, if you are training a text classification task and want to choose Deberta V3 Base for your task
19
+ from https://huggingface.co/microsoft/deberta-v3-base,
20
+ You can choose "HuggingFace Hub" and then write the model name: `microsoft/deberta-v3-base` in the model name field.
21
+
22
+ ![hub-model-choice](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/hub_model_choice.png)
23
+
24
+ Please note that if you are selecting a hub model, you should make sure that it is compatible with your task, otherwise the training will fail.
autotrain-advanced/docs/source/param_choice.mdx ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Parameter Choice
2
+
3
+ Just like model choice, you can choose the parameters for your job in two ways: AutoTrain and Manual.
4
+
5
+ ## AutoTrain Mode
6
+
7
+ In the AutoTrain mode, the parameters for your task-model pair will be chosen automagically.
8
+ If you choose "AutoTrain" as model choice, you get the AutoTrain mode as the only option.
9
+ If you choose "HuggingFace Hub" as model choice, you get the option to choose between AutoTrain and Manual mode for parameter choice.
10
+
11
+ An example of AutoTrain mode for a text classification task is shown below:
12
+
13
+ ![AutoTrain Parameter Choice](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/param_choice_1.png)
14
+
15
+ For most of the tasks in AutoTrain parameter selection mode, you will get "Number of Models" as the only parameter to choose. Some tasks like text-classification might ask you about the language of the dataset.
16
+ The more the number of models, the better the final results might be but it might be more expensive too!
17
+
18
+ ## Manual Mode
19
+
20
+ Manual mode can be used only when you choose "HuggingFace Hub" as model choice. In this mode, you can choose the parameters for your task-model pair manually.
21
+ An example of Manual mode for a text classification task is shown below:
22
+
23
+ ![Manual Parameter Choice](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/param_choice_2.png)
24
+
25
+ In the manual mode, you have to add the jobs on your own. So, carefully select your parameters, click on "Add Job" and 💥.
autotrain-advanced/docs/source/support.mdx ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Help and Support
2
+
3
+ To get help and support for autotrain, there are 3 ways:
4
+
5
+ - [Create an issue](https://github.com/huggingface/autotrain-advanced/issues/new) in AutoTrain Advanced GitHub repository.
6
+
7
+ - [Ask in the Hugging Face Forum](https://discuss.huggingface.co/c/autotrain/16).
8
+
9
+ - [Email us](mailto:autotrain@hf.co) directly.
10
+
11
+
12
+ Please don't forget to mention your username and project name if you have a specific question about your project.
autotrain-advanced/docs/source/text_classification.mdx ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text Classification
2
+
3
+ Training a text classification model with AutoTrain is super-easy! Get your data ready in
4
+ proper format and then with just a few clicks, your state-of-the-art model will be ready to
5
+ be used in production.
6
+
7
+ ## Data Format
8
+
9
+ Let's train a model for classifying the sentiment of a movie review. The data should be
10
+ in the following CSV format:
11
+
12
+ ```csv
13
+ review,sentiment
14
+ "this movie is great",positive
15
+ "this movie is bad",negative
16
+ .
17
+ .
18
+ .
19
+ ```
20
+
21
+ As you can see, we have two columns in the CSV file. One column is the text and the other
22
+ is the label. The label can be any string. In this example, we have two labels: `positive`
23
+ and `negative`. You can have as many labels as you want.
24
+
25
+ If your CSV is huge, you can divide it into multiple CSV files and upload them separately.
26
+ Please make sure that the column names are the same in all CSV files.
27
+
28
+ One way to divide the CSV file using pandas is as follows:
29
+
30
+ ```python
31
+ import pandas as pd
32
+
33
+ # Set the chunk size
34
+ chunk_size = 1000
35
+ i = 1
36
+
37
+ # Open the CSV file and read it in chunks
38
+ for chunk in pd.read_csv('example.csv', chunksize=chunk_size):
39
+ # Save each chunk to a new file
40
+ chunk.to_csv(f'chunk_{i}.csv', index=False)
41
+ i += 1
42
+ ```
43
+
44
+ Once the data has been uploaded, you have to select the proper column mapping
45
+
46
+ ## Column Mapping
47
+
48
+ ![Column Mapping](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/text_classification_1.png)
49
+
50
+ In our example, the text column is called `review` and the label column is called `sentiment`.
51
+ Thus, we have to select `review` for the text column and `sentiment` for the label column.
52
+ Please note that, if column mapping is not done correctly, the training will fail.
53
+
54
+
55
+ ## Training
56
+
57
+ Once you have uploaded the data, selected the column mapping, and set the hyperparameters (AutoTrain or Manual mode), you can start the training.
58
+ To start the training, please confirm the estimated cost and click on the `Create Project` button.
59
+
60
+
autotrain-advanced/examples/text_classification_binary.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from uuid import uuid4

from datasets import load_dataset

from autotrain.dataset import AutoTrainDataset
from autotrain.project import Project


# Example: binary text classification on IMDB with AutoTrain.
RANDOM_ID = str(uuid4())
DATASET = "imdb"
PROJECT_NAME = f"imdb_{RANDOM_ID}"
TASK = "text_binary_classification"
MODEL = "bert-base-uncased"

USERNAME = os.environ["AUTOTRAIN_USERNAME"]
TOKEN = os.environ["HF_TOKEN"]


if __name__ == "__main__":
    # Load the dataset splits and convert them to pandas dataframes.
    raw = load_dataset(DATASET)
    train_df = raw["train"].to_pandas()
    valid_df = raw["test"].to_pandas()

    # Prepare (process and upload) the dataset for AutoTrain.
    dset = AutoTrainDataset(
        train_data=[train_df],
        valid_data=[valid_df],
        task=TASK,
        token=TOKEN,
        project_name=PROJECT_NAME,
        username=USERNAME,
        column_mapping={"text": "text", "label": "label"},
        percent_valid=None,
    )
    dset.prepare()

    #
    # How to get params for a task:
    #
    # from autotrain.params import Params
    # params = Params(task=TASK, training_type="hub_model").get()
    # print(params) to get full list of params for the task

    # Three training jobs with different hyperparameter combinations:
    # (learning_rate, optimizer, scheduler), all for 5 epochs.
    hyperparam_grid = [
        (1e-5, "adamw_torch", "linear"),
        (3e-5, "adamw_torch", "cosine"),
        (5e-5, "sgd", "cosine"),
    ]
    jobs = [
        {
            "task": TASK,
            "learning_rate": lr,
            "optimizer": opt,
            "scheduler": sched,
            "epochs": 5,
        }
        for lr, opt, sched in hyperparam_grid
    ]

    # Create the project and approve it to start training.
    project = Project(dataset=dset, hub_model=MODEL, job_params=jobs)
    project_id = project.create()
    project.approve(project_id)
autotrain-advanced/examples/text_classification_multiclass.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from uuid import uuid4

from datasets import load_dataset

from autotrain.dataset import AutoTrainDataset
from autotrain.project import Project


# Example: multi-class text classification on amazon_reviews_multi (English)
# with AutoTrain, predicting the star rating from the review body.
RANDOM_ID = str(uuid4())
DATASET = "amazon_reviews_multi"
PROJECT_NAME = f"amazon_reviews_multi_{RANDOM_ID}"
TASK = "text_multi_class_classification"
MODEL = "bert-base-uncased"

USERNAME = os.environ["AUTOTRAIN_USERNAME"]
TOKEN = os.environ["HF_TOKEN"]


if __name__ == "__main__":
    # Load the English subset and convert the splits to pandas dataframes.
    raw = load_dataset(DATASET, "en")
    train_df = raw["train"].to_pandas()
    valid_df = raw["test"].to_pandas()

    # Prepare (process and upload) the dataset for AutoTrain.
    dset = AutoTrainDataset(
        train_data=[train_df],
        valid_data=[valid_df],
        task=TASK,
        token=TOKEN,
        project_name=PROJECT_NAME,
        username=USERNAME,
        column_mapping={"text": "review_body", "label": "stars"},
        percent_valid=None,
    )
    dset.prepare()

    #
    # How to get params for a task:
    #
    # from autotrain.params import Params
    # params = Params(task=TASK, training_type="hub_model").get()
    # print(params) to get full list of params for the task

    # Three training jobs with different hyperparameter combinations:
    # (learning_rate, optimizer, scheduler), all for 5 epochs.
    hyperparam_grid = [
        (1e-5, "adamw_torch", "linear"),
        (3e-5, "adamw_torch", "cosine"),
        (5e-5, "sgd", "cosine"),
    ]
    jobs = [
        {
            "task": TASK,
            "learning_rate": lr,
            "optimizer": opt,
            "scheduler": sched,
            "epochs": 5,
        }
        for lr, opt, sched in hyperparam_grid
    ]

    # Create the project and approve it to start training.
    project = Project(dataset=dset, hub_model=MODEL, job_params=jobs)
    project_id = project.create()
    project.approve(project_id)
autotrain-advanced/requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations==1.3.1
2
+ codecarbon==2.2.3
3
+ datasets[vision]~=2.14.0
4
+ evaluate==0.3.0
5
+ ipadic==1.0.0
6
+ jiwer==3.0.2
7
+ joblib==1.3.1
8
+ loguru==0.7.0
9
+ pandas==2.0.3
10
+ Pillow==10.0.0
11
+ protobuf==4.23.4
12
+ pydantic==1.10.11
13
+ sacremoses==0.0.53
14
+ scikit-learn==1.3.0
15
+ sentencepiece==0.1.99
16
+ tqdm==4.65.0
17
+ werkzeug==2.3.6
18
+ huggingface_hub>=0.16.4
19
+ requests==2.31.0
20
+ gradio==3.39.0
21
+ einops==0.6.1
22
+ invisible-watermark==0.2.0
23
+ # latest versions
24
+ tensorboard
25
+ peft
26
+ trl
27
+ tiktoken
28
+ transformers
29
+ accelerate
30
+ diffusers
31
+ bitsandbytes
autotrain-advanced/setup.cfg ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ license_files = LICENSE
3
+ version = attr: autotrain.__version__
4
+
5
+ [isort]
6
+ ensure_newline_before_comments = True
7
+ force_grid_wrap = 0
8
+ include_trailing_comma = True
9
+ line_length = 119
10
+ lines_after_imports = 2
11
+ multi_line_output = 3
12
+ use_parentheses = True
13
+
14
+ [flake8]
15
+ ignore = E203, E501, W503
16
+ max-line-length = 119
17
+ per-file-ignores =
18
+ # imported but unused
19
+ __init__.py: F401
20
+ exclude =
21
+ .git,
22
+ .venv,
23
+ __pycache__,
24
+ dist
autotrain-advanced/setup.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Lint as: python3
"""
HuggingFace / AutoTrain Advanced
"""
import os

from setuptools import find_packages, setup


DOCLINES = __doc__.split("\n")

this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
    LONG_DESCRIPTION = f.read()

# Get INSTALL_REQUIRES from requirements.txt. Comment lines (e.g. the
# "# latest versions" marker) and blank lines are valid in a pip
# requirements file but are NOT valid requirement specifiers, so they
# must be filtered out before being handed to setuptools.
with open(os.path.join(this_directory, "requirements.txt"), encoding="utf-8") as f:
    INSTALL_REQUIRES = [
        line.strip() for line in f if line.strip() and not line.strip().startswith("#")
    ]

QUALITY_REQUIRE = [
    "black",
    "isort",
    "flake8==3.7.9",
]

TESTS_REQUIRE = ["pytest"]


EXTRAS_REQUIRE = {
    "dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
    "quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
    "docs": INSTALL_REQUIRES
    + [
        "recommonmark",
        "sphinx==3.1.2",
        "sphinx-markdown-tables",
        "sphinx-rtd-theme==0.4.3",
        "sphinx-copybutton",
    ],
}

setup(
    name="autotrain-advanced",
    # First docstring line is the short description shown on PyPI.
    description=DOCLINES[0],
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    author="HuggingFace Inc.",
    author_email="autotrain@huggingface.co",
    url="https://github.com/huggingface/autotrain-advanced",
    download_url="https://github.com/huggingface/autotrain-advanced/tags",
    license="Apache 2.0",
    package_dir={"": "src"},
    packages=find_packages("src"),
    extras_require=EXTRAS_REQUIRE,
    install_requires=INSTALL_REQUIRES,
    entry_points={"console_scripts": ["autotrain=autotrain.cli.autotrain:main"]},
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="automl autonlp autotrain huggingface",
)
autotrain-advanced/src/autotrain/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# coding=utf-8
# Copyright 2020-2021 The HuggingFace AutoTrain Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
# pylint: enable=line-too-long
import os


# Silence the bitsandbytes startup banner; set before any submodule
# imports bitsandbytes so the flag is visible to it.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
__version__ = "0.6.16.dev0"
autotrain-advanced/src/autotrain/app.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import string
5
+ import zipfile
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ from huggingface_hub import list_models
10
+ from loguru import logger
11
+
12
+ from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset
13
+ from autotrain.languages import SUPPORTED_LANGUAGES
14
+ from autotrain.params import Params
15
+ from autotrain.project import Project
16
+ from autotrain.utils import get_project_cost, get_user_token, user_authentication
17
+
18
+
19
# UI task groups shown in the "project type" dropdown; each group maps to
# the task display names offered for it.
APP_TASKS = {
    "Natural Language Processing": ["Text Classification", "LLM Finetuning"],
    # "Tabular": TABULAR_TASKS,
    "Computer Vision": ["Image Classification", "Dreambooth"],
}

# Task display name -> internal AutoTrain task identifier.
APP_TASKS_MAPPING = {
    "Text Classification": "text_multi_class_classification",
    "LLM Finetuning": "lm_training",
    "Image Classification": "image_multi_class_classification",
    "Dreambooth": "dreambooth",
}

# Internal task identifier -> UI task group.
# NOTE(review): the text/image keys here ("text_classification",
# "image_classification") do not match the identifiers produced by
# APP_TASKS_MAPPING ("text_multi_class_classification",
# "image_multi_class_classification") — confirm which form lookups use.
APP_TASK_TYPE_MAPPING = {
    "text_classification": "Natural Language Processing",
    "lm_training": "Natural Language Processing",
    "image_classification": "Computer Vision",
    "dreambooth": "Computer Vision",
}

# File extensions accepted by the upload widget; both lower- and
# upper-case variants are listed explicitly.
ALLOWED_FILE_TYPES = [
    ".csv",
    ".CSV",
    ".jsonl",
    ".JSONL",
    ".zip",
    ".ZIP",
    ".png",
    ".PNG",
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
]
53
+
54
+
55
def _login_user(user_token):
    """Authenticate *user_token* and derive billing/ownership candidates.

    Returns a 3-tuple:
        (user_token,
         names that can be billed for a project,
         names that can own/run a training project).
    """
    user_info = user_authentication(token=user_token)
    username = user_info["name"]
    orgs = user_info["orgs"]

    # Billable orgs: payment enabled AND the user has write access there.
    payable_orgs = [
        org["name"]
        for org in orgs
        if org["canPay"] is True and org["roleInOrg"] in ("admin", "write")
    ]

    # The user's own account is billable only if it can pay itself.
    valid_can_pay = ([username] if user_info["canPay"] else []) + payable_orgs

    # Any org membership (plus the user) can be the project namespace.
    who_is_training = [username] + [org["name"] for org in orgs]
    return user_token, valid_can_pay, who_is_training
69
+
70
+
71
def _update_task_type(project_type):
    """Refresh the task dropdown with the tasks of *project_type*,
    defaulting to the first one."""
    tasks = APP_TASKS[project_type]
    return gr.Dropdown.update(value=tasks[0], choices=tasks, visible=True)
77
+
78
+
79
def _update_model_choice(task, autotrain_backend):
    """Refresh the model-choice dropdown for a task/backend pair.

    External backends and LLM finetuning only support Hub models; other
    internal-backend tasks may also use AutoTrain's automatic choice.
    """
    # TODO: add tabular and remember, for tabular, we only support AutoTrain
    if autotrain_backend.lower() != "huggingface internal":
        model_choice = ["HuggingFace Hub"]
    elif task == "LLM Finetuning":
        model_choice = ["HuggingFace Hub"]
    else:
        model_choice = ["AutoTrain", "HuggingFace Hub"]

    return gr.Dropdown.update(
        value=model_choice[0],
        choices=model_choice,
        visible=True,
    )
99
+
100
+
101
def _update_file_type(task):
    """Refresh the file-type radio to the upload formats valid for *task*."""
    task = APP_TASKS_MAPPING[task]
    if task in ("text_multi_class_classification", "lm_training"):
        value, choices = "CSV", ["CSV", "JSONL"]
    elif task == "image_multi_class_classification":
        value, choices = "ZIP", ["Image Subfolders", "ZIP"]
    elif task == "dreambooth":
        value, choices = "ZIP", ["Image Folder", "ZIP"]
    else:
        raise NotImplementedError
    return gr.Radio.update(value=value, choices=choices, visible=True)
123
+
124
+
125
def _update_param_choice(model_choice, autotrain_backend):
    """Refresh the parameter-mode dropdown.

    Hub models on the internal backend may use AutoTrain or Manual
    parameter selection; AutoTrain model choice only AutoTrain; any
    non-internal backend is forced to Manual.
    """
    logger.info(f"model_choice: {model_choice}")
    choices = ["AutoTrain", "Manual"] if model_choice == "HuggingFace Hub" else ["AutoTrain"]
    # Compare the backend name case-insensitively, consistent with the
    # `.lower() != "huggingface internal"` checks used elsewhere in this
    # module; a case-sensitive compare here wrongly forced "Manual" for
    # backend strings that differ only in casing.
    choices = ["Manual"] if autotrain_backend.lower() != "huggingface internal" else choices
    return gr.Dropdown.update(
        value=choices[0],
        choices=choices,
        visible=True,
    )
134
+
135
+
136
def _project_type_update(project_type, task_type, autotrain_backend):
    """Cascade dropdown refreshes after the project type changes:
    task -> model choice -> parameter mode -> hub model list."""
    logger.info(f"project_type: {project_type}, task_type: {task_type}")
    task_update = _update_task_type(project_type)
    model_update = _update_model_choice(task_update["value"], autotrain_backend)
    param_update = _update_param_choice(model_update["value"], autotrain_backend)
    hub_update = _update_hub_model_choices(task_update["value"], model_update["value"])
    return [task_update, model_update, param_update, hub_update]
147
+
148
+
149
def _task_type_update(task_type, autotrain_backend):
    """Cascade dropdown refreshes after the task changes:
    model choice -> parameter mode -> hub model list."""
    logger.info(f"task_type: {task_type}")
    model_update = _update_model_choice(task_type, autotrain_backend)
    param_update = _update_param_choice(model_update["value"], autotrain_backend)
    hub_update = _update_hub_model_choices(task_type, model_update["value"])
    return [model_update, param_update, hub_update]
158
+
159
+
160
def _update_col_map(training_data, task):
    """Show/hide the column-mapping widgets appropriate for *task*.

    Returns updates for [text-column dropdown, target-column dropdown,
    concept-token textbox].
    """
    task = APP_TASKS_MAPPING[task]

    if task == "text_multi_class_classification":
        # Peek at the first file's header row for the available columns.
        data_cols = pd.read_csv(training_data[0].name, nrows=2).columns.tolist()
        return [
            gr.Dropdown.update(visible=True, choices=data_cols, label="Map `text` column", value=data_cols[0]),
            gr.Dropdown.update(visible=True, choices=data_cols, label="Map `target` column", value=data_cols[1]),
            gr.Text.update(visible=False),
        ]

    if task == "lm_training":
        # LM training needs only a text column; no target mapping.
        data_cols = pd.read_csv(training_data[0].name, nrows=2).columns.tolist()
        return [
            gr.Dropdown.update(visible=True, choices=data_cols, label="Map `text` column", value=data_cols[0]),
            gr.Dropdown.update(visible=False),
            gr.Text.update(visible=False),
        ]

    if task == "dreambooth":
        # Dreambooth takes a free-text concept token instead of columns.
        return [
            gr.Dropdown.update(visible=False),
            gr.Dropdown.update(visible=False),
            gr.Text.update(visible=True, label="Concept Token", interactive=True),
        ]

    # Tasks without column mapping: hide everything.
    return [
        gr.Dropdown.update(visible=False),
        gr.Dropdown.update(visible=False),
        gr.Text.update(visible=False),
    ]
188
+
189
+
190
def _estimate_costs(
    training_data, validation_data, task, user_token, autotrain_username, training_params_txt, autotrain_backend
):
    """Estimate the training cost for the current UI selections.

    Counts training/validation samples for the selected task, multiplies by
    the number of configured jobs/models, and queries the AutoTrain cost API.

    Returns:
        A two-element list of Gradio updates: a Markdown update carrying the
        cost message (or an error message) and a hidden Number update.
    """
    # Cost estimation only applies to the internal Hugging Face backend.
    if autotrain_backend.lower() != "huggingface internal":
        return [
            gr.Markdown.update(
                value="Cost estimation is not available for this backend",
                visible=True,
            ),
            gr.Number.update(visible=False),
        ]
    try:
        logger.info("Estimating costs....")
        if training_data is None:
            return [
                gr.Markdown.update(
                    value="Could not estimate cost. Please add training data",
                    visible=True,
                ),
                gr.Number.update(visible=False),
            ]
        if validation_data is None:
            validation_data = []

        training_params = json.loads(training_params_txt)
        if len(training_params) == 0:
            # Fixed typo in user-facing message: "atleast" -> "at least".
            return [
                gr.Markdown.update(
                    value="Could not estimate cost. Please add at least one job",
                    visible=True,
                ),
                gr.Number.update(visible=False),
            ]
        elif len(training_params) == 1:
            # A single "autotrain" job may train several models; a single
            # manual job trains exactly one.
            num_models = training_params[0].get("num_models", 1)
        else:
            num_models = len(training_params)
        task = APP_TASKS_MAPPING[task]
        num_samples = 0
        logger.info("Estimating number of samples")
        if task in ("text_multi_class_classification", "lm_training"):
            # CSV-based tasks: sample count is the total number of rows.
            for _f in training_data:
                num_samples += pd.read_csv(_f.name).shape[0]
            for _f in validation_data:
                num_samples += pd.read_csv(_f.name).shape[0]
        elif task == "image_multi_class_classification":
            logger.info(f"training_data: {training_data}")
            if len(training_data) > 1:
                return [
                    gr.Markdown.update(
                        value="Only one training file is supported for image classification",
                        visible=True,
                    ),
                    gr.Number.update(visible=False),
                ]
            if len(validation_data) > 1:
                return [
                    gr.Markdown.update(
                        value="Only one validation file is supported for image classification",
                        visible=True,
                    ),
                    gr.Number.update(visible=False),
                ]
            # Image tasks ship a zip archive per split; count its members.
            # Use context managers so the archives are always closed
            # (the original leaked open ZipFile handles).
            for _f in training_data:
                with zipfile.ZipFile(_f.name, "r") as zip_ref:
                    num_samples += len(zip_ref.namelist())
            for _f in validation_data:
                with zipfile.ZipFile(_f.name, "r") as zip_ref:
                    num_samples += len(zip_ref.namelist())
        elif task == "dreambooth":
            # Dreambooth uploads individual concept images.
            num_samples = len(training_data)
        else:
            raise NotImplementedError

        logger.info(f"Estimating costs for: num_models: {num_models}, task: {task}, num_samples: {num_samples}")
        estimated_cost = get_project_cost(
            username=autotrain_username,
            token=user_token,
            task=task,
            num_samples=num_samples,
            num_models=num_models,
        )
        logger.info(f"Estimated_cost: {estimated_cost}")
        return [
            gr.Markdown.update(
                value=f"Estimated cost: ${estimated_cost:.2f}. Note: clicking on 'Create Project' will start training and incur charges!",
                visible=True,
            ),
            gr.Number.update(visible=False),
        ]
    except Exception as e:
        # Cost estimation is best-effort: surface a generic message in the UI
        # instead of crashing the event handler.
        logger.error(e)
        logger.error("Could not estimate cost, check inputs")
        return [
            gr.Markdown.update(
                value="Could not estimate cost, check inputs",
                visible=True,
            ),
            gr.Number.update(visible=False),
        ]
295
+
296
+
297
def get_job_params(param_choice, training_params, task):
    """Attach the task name to every job and strip UI-only fields.

    Args:
        param_choice: "autotrain" (managed search, exactly one job allowed)
            or "manual" (one job per parameter set). Matched
            case-insensitively — the original compared "autotrain" exactly
            but "manual" lower-cased, an inconsistency fixed here.
        training_params: list of job dicts (mutated in place).
        task: task identifier to inject into each job.

    Returns:
        The (mutated) list of job dicts.

    Raises:
        ValueError: if more than one job is supplied for AutoTrain mode.
    """
    choice = param_choice.lower()
    if choice == "autotrain":
        if len(training_params) > 1:
            raise ValueError("❌ Only one job parameter is allowed for AutoTrain.")
        training_params[0].update({"task": task})
    elif choice == "manual":
        for params in training_params:
            params.update({"task": task})
            # hub_model is UI-selection metadata, not a training hyperparameter.
            params.pop("hub_model", None)
    return training_params
309
+
310
+
311
def _update_project_name():
    """Suggest a random project name and toggle the Create button.

    The button is disabled while a local training run is in progress,
    signalled by the presence of the /tmp/training tracker file.
    """
    name_parts = (
        "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) for _ in range(3)
    )
    suggested_name = "-".join(name_parts)
    # A tracker file under /tmp means a training run is already underway.
    training_in_progress = os.path.exists(os.path.join("/tmp", "training"))
    return [
        gr.Text.update(value=suggested_name, visible=True, interactive=True),
        gr.Button.update(interactive=not training_in_progress),
    ]
325
+
326
+
327
def _update_hub_model_choices(task, model_choice):
    """Populate the Hub-model dropdown for the selected task.

    Queries the Hugging Face Hub for the most-downloaded public models
    matching the task's pipeline filter(s) and fills the dropdown with them.
    """
    task = APP_TASKS_MAPPING[task]
    logger.info(f"Updating hub model choices for task: {task}, model_choice: {model_choice}")
    if model_choice.lower() == "autotrain":
        # AutoTrain picks its own models; hide the dropdown entirely.
        return gr.Dropdown.update(
            visible=False,
            interactive=False,
        )
    # Hub pipeline filters queried per task, in order.
    task_filters = {
        "text_multi_class_classification": ("fill-mask", "text-classification"),
        "lm_training": ("text-generation",),
        "image_multi_class_classification": ("image-classification",),
        "dreambooth": ("text-to-image",),
    }
    if task not in task_filters:
        raise NotImplementedError
    raw_models = []
    for hub_filter in task_filters[task]:
        raw_models.extend(list_models(filter=hub_filter, sort="downloads", direction=-1, limit=100))
    # Keep only public models, sorted by download count in descending order.
    hub_models = [{"id": m.modelId, "downloads": m.downloads} for m in raw_models if m.private is False]
    hub_models.sort(key=lambda x: x["downloads"], reverse=True)

    choices = [m["id"] for m in hub_models]
    if task == "dreambooth":
        # Pin the SD-XL base model as the dreambooth default.
        choices = ["stabilityai/stable-diffusion-xl-base-1.0"] + choices
    return gr.Dropdown.update(
        choices=choices,
        value=choices[0],
        visible=True,
        interactive=True,
    )
367
+
368
+
369
def _update_backend(backend):
    """Restrict model/param choice dropdowns for non-internal backends.

    Bug fix: the backend dropdown's value is "HuggingFace Internal" (no
    space), but this check compared against "Hugging Face Internal", so the
    internal branch could never be taken. Normalized to match the other
    backend checks in this file (`.lower() != "huggingface internal"`).

    Returns:
        Two Gradio Dropdown updates: (model choice, param choice).
    """
    if backend.lower() != "huggingface internal":
        # Spaces / external backends only support Hub models with manual params.
        return [
            gr.Dropdown.update(
                visible=True,
                interactive=True,
                choices=["HuggingFace Hub"],
                value="HuggingFace Hub",
            ),
            gr.Dropdown.update(
                visible=True,
                interactive=True,
                choices=["Manual"],
                value="Manual",
            ),
        ]
    # Internal backend: leave existing choices alone, just keep both usable.
    return [
        gr.Dropdown.update(
            visible=True,
            interactive=True,
        ),
        gr.Dropdown.update(
            visible=True,
            interactive=True,
        ),
    ]
395
+
396
+
397
def _create_project(
    autotrain_username,
    valid_can_pay,
    project_name,
    user_token,
    task,
    training_data,
    validation_data,
    col_map_text,
    col_map_label,
    concept_token,
    training_params_txt,
    hub_model,
    estimated_cost,
    autotrain_backend,
):
    """Create and launch an AutoTrain project from the current UI state.

    Builds the task-specific dataset, bundles it with the configured jobs
    into a Project, and either submits it to the internal backend (creating
    and approving it) or starts a local run for other backends.

    Raises:
        gr.Error: if the user cannot pay for a non-free project, or no
            training job has been added.
        NotImplementedError: for an unrecognized task.
    """
    task = APP_TASKS_MAPPING[task]
    # valid_can_pay arrives as a comma-separated string from a hidden textbox.
    valid_can_pay = valid_can_pay.split(",")
    can_pay = autotrain_username in valid_can_pay
    logger.info(f"🚨🚨🚨Creating project: {project_name}")
    logger.info(f"🚨Task: {task}")
    logger.info(f"🚨Training data: {training_data}")
    logger.info(f"🚨Validation data: {validation_data}")
    logger.info(f"🚨Training params: {training_params_txt}")
    logger.info(f"🚨Hub model: {hub_model}")
    logger.info(f"🚨Estimated cost: {estimated_cost}")
    logger.info(f"🚨:Can pay: {can_pay}")

    # Free projects (estimated cost 0) are allowed regardless of payment status.
    if can_pay is False and estimated_cost > 0:
        raise gr.Error("❌ You do not have enough credits to create this project. Please add a valid payment method.")

    training_params = json.loads(training_params_txt)
    if len(training_params) == 0:
        raise gr.Error("Please add atleast one job")
    elif len(training_params) == 1:
        # A single job carrying "num_models" is an AutoTrain-managed search.
        if "num_models" in training_params[0]:
            param_choice = "autotrain"
        else:
            param_choice = "manual"
    else:
        param_choice = "manual"

    if task == "image_multi_class_classification":
        # Image data is a single zip archive per split.
        training_data = training_data[0].name
        if validation_data is not None:
            validation_data = validation_data[0].name
        dset = AutoTrainImageClassificationDataset(
            train_data=training_data,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "text_multi_class_classification":
        # CSV files; map the user-selected text/label columns.
        training_data = [f.name for f in training_data]
        if validation_data is None:
            validation_data = []
        else:
            validation_data = [f.name for f in validation_data]
        dset = AutoTrainDataset(
            train_data=training_data,
            task=task,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            column_mapping={"text": col_map_text, "label": col_map_label},
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "lm_training":
        # LM training only needs a text column; no label mapping.
        training_data = [f.name for f in training_data]
        if validation_data is None:
            validation_data = []
        else:
            validation_data = [f.name for f in validation_data]
        dset = AutoTrainDataset(
            train_data=training_data,
            task=task,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            column_mapping={"text": col_map_text},
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "dreambooth":
        # Dreambooth takes raw concept images plus a concept token string.
        dset = AutoTrainDreamboothDataset(
            concept_images=training_data,
            concept_name=concept_token,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
        )
    else:
        raise NotImplementedError

    dset.prepare()
    project = Project(
        dataset=dset,
        param_choice=param_choice,
        hub_model=hub_model,
        job_params=get_job_params(param_choice, training_params, task),
    )
    if autotrain_backend.lower() == "huggingface internal":
        project_id = project.create()
        project.approve(project_id)
        return gr.Markdown.update(
            value=f"Project created successfully. Monitor progess on the [dashboard](https://ui.autotrain.huggingface.co/{project_id}/trainings).",
            visible=True,
        )
    else:
        # Non-internal backends run locally / on Spaces; no status markdown
        # is returned (handler implicitly returns None).
        project.create(local=True)
510
+
511
+
512
def get_variable_name(var, namespace):
    """Return the first name in *namespace* bound to exactly *var*.

    Matching is by identity (`is`), not equality. Returns None when no
    name in the mapping refers to the same object.
    """
    matches = (name for name, value in namespace.items() if value is var)
    return next(matches, None)
517
+
518
+
519
def disable_create_project_button():
    """Grey out the Create Project button (called while project creation runs)."""
    return gr.Button.update(interactive=False)
521
+
522
+
523
def main():
    """Build and return the AutoTrain Advanced Gradio app.

    Lays out the full UI (login, project/task selection, data upload,
    hyperparameter controls) and wires all event handlers. Returns the
    gr.Blocks demo; the caller is responsible for launching it.
    """
    with gr.Blocks(theme="freddyaboulton/dracula_revamped") as demo:
        gr.Markdown("## 🤗 AutoTrain Advanced")
        # Token comes from the environment first, then the login helper.
        user_token = os.environ.get("HF_TOKEN", "")

        if len(user_token) == 0:
            user_token = get_user_token()

        if user_token is None:
            # Not logged in: render instructions only and bail out early.
            gr.Markdown(
                """Please login with a write [token](https://huggingface.co/settings/tokens).
                Pass your HF token in an environment variable called `HF_TOKEN` and then restart this app.
                """
            )
            return demo

        user_token, valid_can_pay, who_is_training = _login_user(user_token)

        if user_token is None or len(user_token) == 0:
            gr.Error("Please login with a write token.")

        # Hidden state components: token and the comma-joined list of users
        # with a valid payment method (read back by _create_project).
        user_token = gr.Textbox(
            value=user_token, type="password", lines=1, max_lines=1, visible=False, interactive=False
        )
        valid_can_pay = gr.Textbox(value=",".join(valid_can_pay), visible=False, interactive=False)
        # --- Top row: account, backend, project/task/model selection ---
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    autotrain_username = gr.Dropdown(
                        label="AutoTrain Username",
                        choices=who_is_training,
                        value=who_is_training[0] if who_is_training else "",
                    )
                    autotrain_backend = gr.Dropdown(
                        label="AutoTrain Backend",
                        choices=["HuggingFace Internal", "HuggingFace Spaces"],
                        value="HuggingFace Internal",
                        interactive=True,
                    )
                with gr.Row():
                    project_name = gr.Textbox(label="Project name", value="", lines=1, max_lines=1, interactive=True)
                    project_type = gr.Dropdown(
                        label="Project Type", choices=list(APP_TASKS.keys()), value=list(APP_TASKS.keys())[0]
                    )
                    task_type = gr.Dropdown(
                        label="Task",
                        choices=APP_TASKS[list(APP_TASKS.keys())[0]],
                        value=APP_TASKS[list(APP_TASKS.keys())[0]][0],
                        interactive=True,
                    )
                    model_choice = gr.Dropdown(
                        label="Model Choice",
                        choices=["AutoTrain", "HuggingFace Hub"],
                        value="AutoTrain",
                        visible=True,
                        interactive=True,
                    )
                    # Hidden until the user picks "HuggingFace Hub" above.
                    hub_model = gr.Dropdown(
                        label="Hub Model",
                        value="",
                        visible=False,
                        interactive=True,
                        elem_id="hub_model",
                    )
        gr.Markdown("<hr>")
        # --- Main area: Data / Params tabs on the left, job table on the right ---
        with gr.Row():
            with gr.Column():
                with gr.Tabs(elem_id="tabs"):
                    with gr.TabItem("Data"):
                        with gr.Column():
                            # file_type_training = gr.Radio(
                            #     label="File Type",
                            #     choices=["CSV", "JSONL"],
                            #     value="CSV",
                            #     visible=True,
                            #     interactive=True,
                            # )
                            training_data = gr.File(
                                label="Training Data",
                                file_types=ALLOWED_FILE_TYPES,
                                file_count="multiple",
                                visible=True,
                                interactive=True,
                                elem_id="training_data_box",
                            )
                            with gr.Accordion("Validation Data (Optional)", open=False):
                                validation_data = gr.File(
                                    label="Validation Data (Optional)",
                                    file_types=ALLOWED_FILE_TYPES,
                                    file_count="multiple",
                                    visible=True,
                                    interactive=True,
                                    elem_id="validation_data_box",
                                )
                            # Column-mapping widgets; which ones are visible
                            # depends on the task (see _update_col_map).
                            with gr.Row():
                                col_map_text = gr.Dropdown(
                                    label="Text Column", choices=[], visible=False, interactive=True
                                )
                                col_map_target = gr.Dropdown(
                                    label="Target Column", choices=[], visible=False, interactive=True
                                )
                                concept_token = gr.Text(
                                    value="", visible=False, interactive=True, lines=1, max_lines=1
                                )
                    with gr.TabItem("Params"):
                        # Each widget's elem_id matches a key in Params.get();
                        # _update_params toggles visibility by those ids.
                        with gr.Row():
                            source_language = gr.Dropdown(
                                label="Source Language",
                                choices=SUPPORTED_LANGUAGES[:-1],
                                value="en",
                                visible=True,
                                interactive=True,
                                elem_id="source_language",
                            )
                            num_models = gr.Slider(
                                label="Number of Models",
                                minimum=1,
                                maximum=25,
                                value=5,
                                step=1,
                                visible=True,
                                interactive=True,
                                elem_id="num_models",
                            )
                            target_language = gr.Dropdown(
                                label="Target Language",
                                choices=["fr"],
                                value="fr",
                                visible=False,
                                interactive=True,
                                elem_id="target_language",
                            )
                            image_size = gr.Number(
                                label="Image Size",
                                value=512,
                                visible=False,
                                interactive=True,
                                elem_id="image_size",
                            )

                        with gr.Row():
                            learning_rate = gr.Number(
                                label="Learning Rate",
                                value=5e-5,
                                visible=False,
                                interactive=True,
                                elem_id="learning_rate",
                            )
                            batch_size = gr.Number(
                                label="Train Batch Size",
                                value=32,
                                visible=False,
                                interactive=True,
                                elem_id="train_batch_size",
                            )
                            num_epochs = gr.Number(
                                label="Number of Epochs",
                                value=3,
                                visible=False,
                                interactive=True,
                                elem_id="num_train_epochs",
                            )
                        with gr.Row():
                            gradient_accumulation_steps = gr.Number(
                                label="Gradient Accumulation Steps",
                                value=1,
                                visible=False,
                                interactive=True,
                                elem_id="gradient_accumulation_steps",
                            )
                            percentage_warmup_steps = gr.Number(
                                label="Percentage of Warmup Steps",
                                value=0.1,
                                visible=False,
                                interactive=True,
                                elem_id="percentage_warmup",
                            )
                            weight_decay = gr.Number(
                                label="Weight Decay",
                                value=0.01,
                                visible=False,
                                interactive=True,
                                elem_id="weight_decay",
                            )
                        with gr.Row():
                            lora_r = gr.Number(
                                label="LoraR",
                                value=16,
                                visible=False,
                                interactive=True,
                                elem_id="lora_r",
                            )
                            lora_alpha = gr.Number(
                                label="LoraAlpha",
                                value=32,
                                visible=False,
                                interactive=True,
                                elem_id="lora_alpha",
                            )
                            lora_dropout = gr.Number(
                                label="Lora Dropout",
                                value=0.1,
                                visible=False,
                                interactive=True,
                                elem_id="lora_dropout",
                            )
                        with gr.Row():
                            db_num_steps = gr.Number(
                                label="Num Steps",
                                value=500,
                                visible=False,
                                interactive=True,
                                elem_id="num_steps",
                            )
                        with gr.Row():
                            optimizer = gr.Dropdown(
                                label="Optimizer",
                                choices=["adamw_torch", "adamw_hf", "sgd", "adafactor", "adagrad"],
                                value="adamw_torch",
                                visible=False,
                                interactive=True,
                                elem_id="optimizer",
                            )
                            scheduler = gr.Dropdown(
                                label="Scheduler",
                                choices=["linear", "cosine"],
                                value="linear",
                                visible=False,
                                interactive=True,
                                elem_id="scheduler",
                            )

                        add_job_button = gr.Button(
                            value="Add Job",
                            visible=True,
                            interactive=True,
                            elem_id="add_job",
                        )
                        # clear_jobs_button = gr.Button(
                        #     value="Clear Jobs",
                        #     visible=True,
                        #     interactive=True,
                        #     elem_id="clear_jobs",
                        # )
                gr.Markdown("<hr>")
                estimated_costs_md = gr.Markdown(value="Estimated Costs: N/A", visible=True, interactive=False)
                estimated_costs_num = gr.Number(value=0, visible=False, interactive=False)
                create_project_button = gr.Button(
                    value="Create Project",
                    visible=True,
                    interactive=True,
                    elem_id="create_project",
                )
            with gr.Column():
                param_choice = gr.Dropdown(
                    label="Param Choice",
                    choices=["AutoTrain"],
                    value="AutoTrain",
                    visible=True,
                    interactive=True,
                )
                # Job list is kept as JSON text (source of truth) plus a
                # DataFrame rendering for display.
                training_params_txt = gr.Text(value="[]", visible=False, interactive=False)
                training_params_md = gr.DataFrame(visible=False, interactive=False)

        final_output = gr.Markdown(value="", visible=True, interactive=False)
        # All hyperparameter widgets, toggled as a group by _update_params.
        hyperparameters = [
            hub_model,
            num_models,
            source_language,
            target_language,
            learning_rate,
            batch_size,
            num_epochs,
            gradient_accumulation_steps,
            lora_r,
            lora_alpha,
            lora_dropout,
            optimizer,
            scheduler,
            percentage_warmup_steps,
            weight_decay,
            db_num_steps,
            image_size,
        ]

        def _update_params(params_data):
            """Show only the hyperparameter widgets relevant to the current
            task / param-choice / model-choice, and reset the job list."""
            _task = params_data[task_type]
            _task = APP_TASKS_MAPPING[_task]
            params = Params(
                task=_task,
                param_choice="autotrain" if params_data[param_choice] == "AutoTrain" else "manual",
                model_choice="autotrain" if params_data[model_choice] == "AutoTrain" else "hub_model",
            )
            params = params.get()
            visible_params = []
            for param in hyperparameters:
                if param.elem_id in params.keys():
                    visible_params.append(param.elem_id)
            op = [h.update(visible=h.elem_id in visible_params) for h in hyperparameters]
            op.append(add_job_button.update(visible=True))
            op.append(training_params_md.update(visible=False))
            op.append(training_params_txt.update(value="[]"))
            return op

        # --- Event wiring: selection changes cascade through the UI ---
        autotrain_backend.change(
            _project_type_update,
            inputs=[project_type, task_type, autotrain_backend],
            outputs=[task_type, model_choice, param_choice, hub_model],
        )

        project_type.change(
            _project_type_update,
            inputs=[project_type, task_type, autotrain_backend],
            outputs=[task_type, model_choice, param_choice, hub_model],
        )
        task_type.change(
            _task_type_update,
            inputs=[task_type, autotrain_backend],
            outputs=[model_choice, param_choice, hub_model],
        )
        model_choice.change(
            _update_param_choice,
            inputs=[model_choice, autotrain_backend],
            outputs=param_choice,
        ).then(
            _update_hub_model_choices,
            inputs=[task_type, model_choice],
            outputs=hub_model,
        )

        param_choice.change(
            _update_params,
            inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
            outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
        )
        task_type.change(
            _update_params,
            inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
            outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
        )
        model_choice.change(
            _update_params,
            inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
            outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
        )

        def _add_job(params_data):
            """Append the currently-visible hyperparameter values as a job.

            In "autotrain" mode the job list is capped at one entry (the
            latest overwrites); in manual mode jobs accumulate.
            """
            _task = params_data[task_type]
            _task = APP_TASKS_MAPPING[_task]
            _param_choice = "autotrain" if params_data[param_choice] == "AutoTrain" else "manual"
            _model_choice = "autotrain" if params_data[model_choice] == "AutoTrain" else "hub_model"
            if _model_choice == "hub_model" and params_data[hub_model] is None:
                logger.error("Hub model is None")
                return
            _training_params = {}
            params = Params(task=_task, param_choice=_param_choice, model_choice=_model_choice)
            params = params.get()
            for _param in hyperparameters:
                if _param.elem_id in params.keys():
                    _training_params[_param.elem_id] = params_data[_param]
            _training_params_md = json.loads(params_data[training_params_txt])
            if _param_choice == "autotrain":
                if len(_training_params_md) > 0:
                    _training_params_md[0] = _training_params
                    _training_params_md = _training_params_md[:1]
                else:
                    _training_params_md.append(_training_params)
            else:
                _training_params_md.append(_training_params)
            params_df = pd.DataFrame(_training_params_md)
            # remove hub_model column
            if "hub_model" in params_df.columns:
                params_df = params_df.drop(columns=["hub_model"])
            return [
                gr.DataFrame.update(value=params_df, visible=True),
                gr.Textbox.update(value=json.dumps(_training_params_md), visible=False),
            ]

        add_job_button.click(
            _add_job,
            inputs=set(
                [task_type, param_choice, model_choice] + hyperparameters + [training_params_md, training_params_txt]
            ),
            outputs=[training_params_md, training_params_txt],
        )
        col_map_components = [
            col_map_text,
            col_map_target,
            concept_token,
        ]
        training_data.change(
            _update_col_map,
            inputs=[training_data, task_type],
            outputs=col_map_components,
        )
        task_type.change(
            _update_col_map,
            inputs=[training_data, task_type],
            outputs=col_map_components,
        )
        # Cost estimate is refreshed on every input that can affect it.
        estimate_costs_inputs = [
            training_data,
            validation_data,
            task_type,
            user_token,
            autotrain_username,
            training_params_txt,
            autotrain_backend,
        ]
        estimate_costs_outputs = [estimated_costs_md, estimated_costs_num]
        training_data.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
        validation_data.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
        training_params_txt.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
        task_type.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
        add_job_button.click(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)

        # Disable the button first so double-clicks can't create two projects.
        create_project_button.click(disable_create_project_button, None, create_project_button).then(
            _create_project,
            inputs=[
                autotrain_username,
                valid_can_pay,
                project_name,
                user_token,
                task_type,
                training_data,
                validation_data,
                col_map_text,
                col_map_target,
                concept_token,
                training_params_txt,
                hub_model,
                estimated_costs_num,
                autotrain_backend,
            ],
            outputs=final_output,
        )

        demo.load(
            _update_project_name,
            outputs=[project_name, create_project_button],
        )

    return demo
autotrain-advanced/src/autotrain/cli/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from argparse import ArgumentParser
3
+
4
+
5
class BaseAutoTrainCommand(ABC):
    """Abstract base class for `autotrain` CLI sub-commands.

    Concrete commands register their argparse sub-parser via
    `register_subcommand` and implement `run` to execute.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser (and its arguments) to *parser*."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command with the arguments captured at construction."""
        raise NotImplementedError()
autotrain-advanced/src/autotrain/cli/accelerated_autotrain.py ADDED
File without changes
autotrain-advanced/src/autotrain/cli/autotrain.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from .. import __version__
4
+ from .run_app import RunAutoTrainAppCommand
5
+ from .run_dreambooth import RunAutoTrainDreamboothCommand
6
+ from .run_llm import RunAutoTrainLLMCommand
7
+ from .run_setup import RunSetupCommand
8
+
9
+
10
def main():
    """Entry point for the `autotrain` CLI: parse arguments and dispatch."""
    parser = argparse.ArgumentParser(
        "AutoTrain advanced CLI",
        usage="autotrain <command> [<args>]",
        epilog="For more information about a command, run: `autotrain <command> --help`",
    )
    parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
    commands_parser = parser.add_subparsers(help="commands")

    # Each command class registers its own sub-parser and factory function.
    for command_cls in (
        RunAutoTrainAppCommand,
        RunAutoTrainLLMCommand,
        RunSetupCommand,
        RunAutoTrainDreamboothCommand,
    ):
        command_cls.register_subcommand(commands_parser)

    args = parser.parse_args()

    if args.version:
        print(__version__)
        exit(0)

    # No sub-command selected: show usage and exit with an error status.
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    args.func(args).run()


if __name__ == "__main__":
    main()
autotrain-advanced/src/autotrain/cli/run_app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+
3
+ from . import BaseAutoTrainCommand
4
+
5
+
6
def run_app_command_factory(args):
    """Build a RunAutoTrainAppCommand from parsed CLI arguments."""
    return RunAutoTrainAppCommand(args.port, args.host, args.task)
+
13
+
14
class RunAutoTrainAppCommand(BaseAutoTrainCommand):
    """`autotrain app` sub-command: launch the AutoTrain Gradio UI."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `app` sub-parser and its --port/--host/--task args."""
        run_app_parser = parser.add_parser(
            "app",
            description="✨ Run AutoTrain app",
        )
        run_app_parser.add_argument(
            "--port",
            type=int,
            default=7860,
            help="Port to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--host",
            type=str,
            default="127.0.0.1",
            help="Host to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--task",
            type=str,
            required=False,
            help="Task to run",
        )
        run_app_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, port, host, task):
        # Parsed CLI options captured for run().
        self.port = port
        self.host = host
        self.task = task

    def run(self):
        """Build and launch the appropriate Gradio app.

        Bug fix: `--host` and `--port` were parsed and stored but never
        passed to `launch()`, so the server always bound to Gradio's
        defaults. They are now forwarded as server_name/server_port.
        """
        # Dreambooth ships its own UI; everything else uses the generic app.
        if self.task == "dreambooth":
            from ..dreambooth_app import main
        else:
            from ..app import main

        demo = main()
        demo.queue(concurrency_count=10).launch(server_name=self.host, server_port=self.port)
autotrain-advanced/src/autotrain/cli/run_dreambooth.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from argparse import ArgumentParser
4
+
5
+ from loguru import logger
6
+
7
+ from autotrain.cli import BaseAutoTrainCommand
8
+
9
+
10
+ try:
11
+ from autotrain.trainers.dreambooth import train as train_dreambooth
12
+ from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
13
+ from autotrain.trainers.dreambooth.utils import VALID_IMAGE_EXTENSIONS, XL_MODELS
14
+ except ImportError:
15
+ logger.warning(
16
+ "❌ Some DreamBooth components are missing! Please run `autotrain setup` to install it. Ignore this warning if you are not using DreamBooth or running `autotrain setup` already."
17
+ )
18
+
19
+
20
def count_images(directory):
    """Count files in *directory* whose extension is a recognized image extension."""
    total = 0
    for extension in VALID_IMAGE_EXTENSIONS:
        total += len(glob.glob(os.path.join(directory, "*" + extension)))
    return total
25
+
26
+
27
def run_dreambooth_command_factory(args):
    """Wrap the parsed CLI namespace in a RunAutoTrainDreamboothCommand."""
    return RunAutoTrainDreamboothCommand(args)
29
+
30
+
31
+ class RunAutoTrainDreamboothCommand(BaseAutoTrainCommand):
32
+ @staticmethod
33
+ def register_subcommand(parser: ArgumentParser):
34
+ arg_list = [
35
+ {
36
+ "arg": "--model",
37
+ "help": "Model to use for training",
38
+ "required": True,
39
+ "type": str,
40
+ },
41
+ {
42
+ "arg": "--revision",
43
+ "help": "Model revision to use for training",
44
+ "required": False,
45
+ "type": str,
46
+ },
47
+ {
48
+ "arg": "--tokenizer",
49
+ "help": "Tokenizer to use for training",
50
+ "required": False,
51
+ "type": str,
52
+ },
53
+ {
54
+ "arg": "--image-path",
55
+ "help": "Path to the images",
56
+ "required": True,
57
+ "type": str,
58
+ },
59
+ {
60
+ "arg": "--class-image-path",
61
+ "help": "Path to the class images",
62
+ "required": False,
63
+ "type": str,
64
+ },
65
+ {
66
+ "arg": "--prompt",
67
+ "help": "Instance prompt",
68
+ "required": True,
69
+ "type": str,
70
+ },
71
+ {
72
+ "arg": "--class-prompt",
73
+ "help": "Class prompt",
74
+ "required": False,
75
+ "type": str,
76
+ },
77
+ {
78
+ "arg": "--num-class-images",
79
+ "help": "Number of class images",
80
+ "required": False,
81
+ "default": 100,
82
+ "type": int,
83
+ },
84
+ {
85
+ "arg": "--class-labels-conditioning",
86
+ "help": "Class labels conditioning",
87
+ "required": False,
88
+ "type": str,
89
+ },
90
+ {
91
+ "arg": "--prior-preservation",
92
+ "help": "With prior preservation",
93
+ "required": False,
94
+ "action": "store_true",
95
+ },
96
+ {
97
+ "arg": "--prior-loss-weight",
98
+ "help": "Prior loss weight",
99
+ "required": False,
100
+ "default": 1.0,
101
+ "type": float,
102
+ },
103
+ {
104
+ "arg": "--output",
105
+ "help": "Output directory",
106
+ "required": True,
107
+ "type": str,
108
+ },
109
+ {
110
+ "arg": "--seed",
111
+ "help": "Seed",
112
+ "required": False,
113
+ "default": 42,
114
+ "type": int,
115
+ },
116
+ {
117
+ "arg": "--resolution",
118
+ "help": "Resolution",
119
+ "required": True,
120
+ "type": int,
121
+ },
122
+ {
123
+ "arg": "--center-crop",
124
+ "help": "Center crop",
125
+ "required": False,
126
+ "action": "store_true",
127
+ },
128
+ {
129
+ "arg": "--train-text-encoder",
130
+ "help": "Train text encoder",
131
+ "required": False,
132
+ "action": "store_true",
133
+ },
134
+ {
135
+ "arg": "--batch-size",
136
+ "help": "Train batch size",
137
+ "required": False,
138
+ "default": 4,
139
+ "type": int,
140
+ },
141
+ {
142
+ "arg": "--sample-batch-size",
143
+ "help": "Sample batch size",
144
+ "required": False,
145
+ "default": 4,
146
+ "type": int,
147
+ },
148
+ {
149
+ "arg": "--epochs",
150
+ "help": "Number of training epochs",
151
+ "required": False,
152
+ "default": 1,
153
+ "type": int,
154
+ },
155
+ {
156
+ "arg": "--num-steps",
157
+ "help": "Max train steps",
158
+ "required": False,
159
+ "type": int,
160
+ },
161
+ {
162
+ "arg": "--checkpointing-steps",
163
+ "help": "Checkpointing steps",
164
+ "required": False,
165
+ "default": 100000,
166
+ "type": int,
167
+ },
168
+ {
169
+ "arg": "--resume-from-checkpoint",
170
+ "help": "Resume from checkpoint",
171
+ "required": False,
172
+ "type": str,
173
+ },
174
+ {
175
+ "arg": "--gradient-accumulation",
176
+ "help": "Gradient accumulation steps",
177
+ "required": False,
178
+ "default": 1,
179
+ "type": int,
180
+ },
181
+ {
182
+ "arg": "--gradient-checkpointing",
183
+ "help": "Gradient checkpointing",
184
+ "required": False,
185
+ "action": "store_true",
186
+ },
187
+ {
188
+ "arg": "--lr",
189
+ "help": "Learning rate",
190
+ "required": False,
191
+ "default": 5e-4,
192
+ "type": float,
193
+ },
194
+ {
195
+ "arg": "--scale-lr",
196
+ "help": "Scale learning rate",
197
+ "required": False,
198
+ "action": "store_true",
199
+ },
200
+ {
201
+ "arg": "--scheduler",
202
+ "help": "Learning rate scheduler",
203
+ "required": False,
204
+ "default": "constant",
205
+ },
206
+ {
207
+ "arg": "--warmup-steps",
208
+ "help": "Learning rate warmup steps",
209
+ "required": False,
210
+ "default": 0,
211
+ "type": int,
212
+ },
213
+ {
214
+ "arg": "--num-cycles",
215
+ "help": "Learning rate num cycles",
216
+ "required": False,
217
+ "default": 1,
218
+ "type": int,
219
+ },
220
+ {
221
+ "arg": "--lr-power",
222
+ "help": "Learning rate power",
223
+ "required": False,
224
+ "default": 1.0,
225
+ "type": float,
226
+ },
227
+ {
228
+ "arg": "--dataloader-num-workers",
229
+ "help": "Dataloader num workers",
230
+ "required": False,
231
+ "default": 0,
232
+ "type": int,
233
+ },
234
+ {
235
+ "arg": "--use-8bit-adam",
236
+ "help": "Use 8bit adam",
237
+ "required": False,
238
+ "action": "store_true",
239
+ },
240
+ {
241
+ "arg": "--adam-beta1",
242
+ "help": "Adam beta 1",
243
+ "required": False,
244
+ "default": 0.9,
245
+ "type": float,
246
+ },
247
+ {
248
+ "arg": "--adam-beta2",
249
+ "help": "Adam beta 2",
250
+ "required": False,
251
+ "default": 0.999,
252
+ "type": float,
253
+ },
254
+ {
255
+ "arg": "--adam-weight-decay",
256
+ "help": "Adam weight decay",
257
+ "required": False,
258
+ "default": 1e-2,
259
+ "type": float,
260
+ },
261
+ {
262
+ "arg": "--adam-epsilon",
263
+ "help": "Adam epsilon",
264
+ "required": False,
265
+ "default": 1e-8,
266
+ "type": float,
267
+ },
268
+ {
269
+ "arg": "--max-grad-norm",
270
+ "help": "Max grad norm",
271
+ "required": False,
272
+ "default": 1.0,
273
+ "type": float,
274
+ },
275
+ {
276
+ "arg": "--allow-tf32",
277
+ "help": "Allow TF32",
278
+ "required": False,
279
+ "action": "store_true",
280
+ },
281
+ {
282
+ "arg": "--prior-generation-precision",
283
+ "help": "Prior generation precision",
284
+ "required": False,
285
+ "type": str,
286
+ },
287
+ {
288
+ "arg": "--local-rank",
289
+ "help": "Local rank",
290
+ "required": False,
291
+ "default": -1,
292
+ "type": int,
293
+ },
294
+ {
295
+ "arg": "--xformers",
296
+ "help": "Enable xformers memory efficient attention",
297
+ "required": False,
298
+ "action": "store_true",
299
+ },
300
+ {
301
+ "arg": "--pre-compute-text-embeddings",
302
+ "help": "Pre compute text embeddings",
303
+ "required": False,
304
+ "action": "store_true",
305
+ },
306
+ {
307
+ "arg": "--tokenizer-max-length",
308
+ "help": "Tokenizer max length",
309
+ "required": False,
310
+ "type": int,
311
+ },
312
+ {
313
+ "arg": "--text-encoder-use-attention-mask",
314
+ "help": "Text encoder use attention mask",
315
+ "required": False,
316
+ "action": "store_true",
317
+ },
318
+ {
319
+ "arg": "--rank",
320
+ "help": "Rank",
321
+ "required": False,
322
+ "default": 4,
323
+ "type": int,
324
+ },
325
+ {
326
+ "arg": "--xl",
327
+ "help": "XL",
328
+ "required": False,
329
+ "action": "store_true",
330
+ },
331
+ {
332
+ "arg": "--fp16",
333
+ "help": "FP16",
334
+ "required": False,
335
+ "action": "store_true",
336
+ },
337
+ {
338
+ "arg": "--bf16",
339
+ "help": "BF16",
340
+ "required": False,
341
+ "action": "store_true",
342
+ },
343
+ {
344
+ "arg": "--hub-token",
345
+ "help": "Hub token",
346
+ "required": False,
347
+ "type": str,
348
+ },
349
+ {
350
+ "arg": "--hub-model-id",
351
+ "help": "Hub model id",
352
+ "required": False,
353
+ "type": str,
354
+ },
355
+ {
356
+ "arg": "--push-to-hub",
357
+ "help": "Push to hub",
358
+ "required": False,
359
+ "action": "store_true",
360
+ },
361
+ {
362
+ "arg": "--validation-prompt",
363
+ "help": "Validation prompt",
364
+ "required": False,
365
+ "type": str,
366
+ },
367
+ {
368
+ "arg": "--num-validation-images",
369
+ "help": "Number of validation images",
370
+ "required": False,
371
+ "default": 4,
372
+ "type": int,
373
+ },
374
+ {
375
+ "arg": "--validation-epochs",
376
+ "help": "Validation epochs",
377
+ "required": False,
378
+ "default": 50,
379
+ "type": int,
380
+ },
381
+ {
382
+ "arg": "--checkpoints-total-limit",
383
+ "help": "Checkpoints total limit",
384
+ "required": False,
385
+ "type": int,
386
+ },
387
+ {
388
+ "arg": "--validation-images",
389
+ "help": "Validation images",
390
+ "required": False,
391
+ "type": str,
392
+ },
393
+ {
394
+ "arg": "--logging",
395
+ "help": "Logging using tensorboard",
396
+ "required": False,
397
+ "action": "store_true",
398
+ },
399
+ ]
400
+
401
+ run_dreambooth_parser = parser.add_parser("dreambooth", description="✨ Run AutoTrain DreamBooth Training")
402
+ for arg in arg_list:
403
+ if "action" in arg:
404
+ run_dreambooth_parser.add_argument(
405
+ arg["arg"],
406
+ help=arg["help"],
407
+ required=arg.get("required", False),
408
+ action=arg.get("action"),
409
+ default=arg.get("default"),
410
+ )
411
+ else:
412
+ run_dreambooth_parser.add_argument(
413
+ arg["arg"],
414
+ help=arg["help"],
415
+ required=arg.get("required", False),
416
+ type=arg.get("type"),
417
+ default=arg.get("default"),
418
+ )
419
+ run_dreambooth_parser.set_defaults(func=run_dreambooth_command_factory)
420
+
421
+ def __init__(self, args):
422
+ self.args = args
423
+ logger.info(self.args)
424
+
425
+ store_true_arg_names = [
426
+ "center_crop",
427
+ "train_text_encoder",
428
+ "gradient_checkpointing",
429
+ "scale_lr",
430
+ "use_8bit_adam",
431
+ "allow_tf32",
432
+ "xformers",
433
+ "pre_compute_text_embeddings",
434
+ "text_encoder_use_attention_mask",
435
+ "xl",
436
+ "fp16",
437
+ "bf16",
438
+ "push_to_hub",
439
+ "logging",
440
+ "prior_preservation",
441
+ ]
442
+
443
+ for arg_name in store_true_arg_names:
444
+ if getattr(self.args, arg_name) is None:
445
+ setattr(self.args, arg_name, False)
446
+
447
+ if self.args.fp16 and self.args.bf16:
448
+ raise ValueError("❌ Please choose either FP16 or BF16")
449
+
450
+ # check if self.args.image_path is a directory with images
451
+ if not os.path.isdir(self.args.image_path):
452
+ raise ValueError("❌ Please specify a valid image directory")
453
+
454
+ # count the number of images in the directory. valid images are .jpg, .jpeg, .png
455
+ num_images = count_images(self.args.image_path)
456
+ if num_images == 0:
457
+ raise ValueError("❌ Please specify a valid image directory")
458
+
459
+ if self.args.push_to_hub:
460
+ if self.args.hub_model_id is None:
461
+ raise ValueError("❌ Please specify a hub model id")
462
+
463
+ if self.args.model in XL_MODELS:
464
+ self.args.xl = True
465
+
466
+ def run(self):
467
+ logger.info("Running DreamBooth Training")
468
+ params = DreamBoothTrainingParams(**vars(self.args))
469
+ train_dreambooth(params)
autotrain-advanced/src/autotrain/cli/run_llm.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+
3
+ from loguru import logger
4
+
5
+ from autotrain.infer.text_generation import TextGenerationInference
6
+
7
+ from ..trainers.clm import train as train_llm
8
+ from ..trainers.utils import LLMTrainingParams
9
+ from . import BaseAutoTrainCommand
10
+
11
+
12
def run_llm_command_factory(args):
    """Build a :class:`RunAutoTrainLLMCommand` from parsed CLI arguments.

    The attribute names below are listed in the exact positional order of
    the command's constructor.
    """
    attr_names = (
        "train", "deploy", "inference", "data_path", "train_split",
        "valid_split", "text_column", "model", "learning_rate",
        "num_train_epochs", "train_batch_size", "eval_batch_size",
        "warmup_ratio", "gradient_accumulation_steps", "optimizer",
        "scheduler", "weight_decay", "max_grad_norm", "seed",
        "add_eos_token", "block_size", "use_peft", "lora_r", "lora_alpha",
        "lora_dropout", "training_type", "train_on_inputs", "logging_steps",
        "project_name", "evaluation_strategy", "save_total_limit",
        "save_strategy", "auto_find_batch_size", "fp16", "push_to_hub",
        "use_int8", "model_max_length", "repo_id", "use_int4", "trainer",
        "target_modules",
    )
    return RunAutoTrainLLMCommand(*(getattr(args, name) for name in attr_names))
56
+
57
+
58
class RunAutoTrainLLMCommand(BaseAutoTrainCommand):
    """CLI command that trains, deploys, or runs inference for LLMs."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``llm`` subcommand and all of its options to *parser*.

        Options are declared as (flag, help, extra-kwargs) triples:
        store_true flags carry an ``action`` key, everything else a
        ``type`` and optionally a ``default``.
        """
        options = [
            ("--train", "Train the model", {"action": "store_true"}),
            ("--deploy", "Deploy the model", {"action": "store_true"}),
            ("--inference", "Run inference", {"action": "store_true"}),
            ("--data_path", "Train dataset to use", {"type": str}),
            ("--train_split", "Test dataset split to use", {"type": str, "default": "train"}),
            ("--valid_split", "Validation dataset split to use", {"type": str, "default": None}),
            ("--text_column", "Text column to use", {"type": str, "default": "text"}),
            ("--model", "Model to use", {"type": str}),
            ("--learning_rate", "Learning rate to use", {"type": float, "default": 3e-5}),
            ("--num_train_epochs", "Number of training epochs to use", {"type": int, "default": 1}),
            ("--train_batch_size", "Training batch size to use", {"type": int, "default": 2}),
            ("--eval_batch_size", "Evaluation batch size to use", {"type": int, "default": 4}),
            ("--warmup_ratio", "Warmup proportion to use", {"type": float, "default": 0.1}),
            ("--gradient_accumulation_steps", "Gradient accumulation steps to use", {"type": int, "default": 1}),
            ("--optimizer", "Optimizer to use", {"type": str, "default": "adamw_torch"}),
            ("--scheduler", "Scheduler to use", {"type": str, "default": "linear"}),
            ("--weight_decay", "Weight decay to use", {"type": float, "default": 0.0}),
            ("--max_grad_norm", "Max gradient norm to use", {"type": float, "default": 1.0}),
            ("--seed", "Seed to use", {"type": int, "default": 42}),
            ("--add_eos_token", "Add EOS token to use", {"action": "store_true"}),
            ("--block_size", "Block size to use", {"type": int, "default": -1}),
            ("--use_peft", "Use PEFT to use", {"action": "store_true"}),
            ("--lora_r", "Lora r to use", {"type": int, "default": 16}),
            ("--lora_alpha", "Lora alpha to use", {"type": int, "default": 32}),
            ("--lora_dropout", "Lora dropout to use", {"type": float, "default": 0.05}),
            ("--training_type", "Training type to use", {"type": str, "default": "generic"}),
            ("--train_on_inputs", "Train on inputs to use", {"action": "store_true"}),
            ("--logging_steps", "Logging steps to use", {"type": int, "default": -1}),
            ("--project_name", "Output directory", {"type": str}),
            ("--evaluation_strategy", "Evaluation strategy to use", {"type": str, "default": "epoch"}),
            ("--save_total_limit", "Save total limit to use", {"type": int, "default": 1}),
            ("--save_strategy", "Save strategy to use", {"type": str, "default": "epoch"}),
            ("--auto_find_batch_size", "Auto find batch size True/False", {"action": "store_true"}),
            ("--fp16", "FP16 True/False", {"action": "store_true"}),
            ("--push_to_hub", "Push to hub True/False", {"action": "store_true"}),
            ("--use_int8", "Use int8 True/False", {"action": "store_true"}),
            ("--model_max_length", "Model max length to use", {"type": int, "default": 1024}),
            ("--repo_id", "Repo id for hugging face hub", {"type": str}),
            ("--use_int4", "Use int4 True/False", {"action": "store_true"}),
            ("--trainer", "Trainer type to use", {"type": str, "default": "default"}),
            ("--target_modules", "Target modules to use", {"type": str, "default": None}),
        ]
        run_llm_parser = parser.add_parser(
            "llm",
            description="✨ Run AutoTrain LLM training/inference/deployment",
        )
        for flag, help_text, extra in options:
            run_llm_parser.add_argument(flag, help=help_text, required=False, **extra)
        run_llm_parser.set_defaults(func=run_llm_command_factory)

    def __init__(
        self,
        train,
        deploy,
        inference,
        data_path,
        train_split,
        valid_split,
        text_column,
        model,
        learning_rate,
        num_train_epochs,
        train_batch_size,
        eval_batch_size,
        warmup_ratio,
        gradient_accumulation_steps,
        optimizer,
        scheduler,
        weight_decay,
        max_grad_norm,
        seed,
        add_eos_token,
        block_size,
        use_peft,
        lora_r,
        lora_alpha,
        lora_dropout,
        training_type,
        train_on_inputs,
        logging_steps,
        project_name,
        evaluation_strategy,
        save_total_limit,
        save_strategy,
        auto_find_batch_size,
        fp16,
        push_to_hub,
        use_int8,
        model_max_length,
        repo_id,
        use_int4,
        trainer,
        target_modules,
    ):
        """Capture every constructor argument as an attribute and validate.

        When ``inference`` is set, an interactive prompt loop is started
        immediately (type ``exit()`` to leave it).
        """
        # mirror each parameter onto an identically-named attribute
        captured = dict(locals())
        captured.pop("self")
        for name, value in captured.items():
            setattr(self, name, value)

        if self.train:
            # training requires a project, a dataset, and a base model
            if self.project_name is None:
                raise ValueError("Project name must be specified")
            if self.data_path is None:
                raise ValueError("Data path must be specified")
            if self.model is None:
                raise ValueError("Model must be specified")
            if self.push_to_hub and self.repo_id is None:
                raise ValueError("Repo id must be specified for push to hub")

        if self.inference:
            # simple REPL against the trained model
            tgi = TextGenerationInference(self.project_name, use_int4=self.use_int4, use_int8=self.use_int8)
            while True:
                prompt = input("User: ")
                if prompt == "exit()":
                    break
                print(f"Bot: {tgi.chat(prompt)}")

    def run(self):
        """Launch LLM training when ``--train`` was given."""
        logger.info("Running LLM")
        logger.info(f"Train: {self.train}")
        if not self.train:
            return
        # every hyperparameter attribute maps onto the same-named
        # LLMTrainingParams field; only `model` is renamed to `model_name`
        hyperparam_names = (
            "data_path", "train_split", "valid_split", "text_column",
            "learning_rate", "num_train_epochs", "train_batch_size",
            "eval_batch_size", "warmup_ratio", "gradient_accumulation_steps",
            "optimizer", "scheduler", "weight_decay", "max_grad_norm", "seed",
            "add_eos_token", "block_size", "use_peft", "lora_r", "lora_alpha",
            "lora_dropout", "training_type", "train_on_inputs",
            "logging_steps", "project_name", "evaluation_strategy",
            "save_total_limit", "save_strategy", "auto_find_batch_size",
            "fp16", "push_to_hub", "use_int8", "model_max_length", "repo_id",
            "use_int4", "trainer", "target_modules",
        )
        params = LLMTrainingParams(
            model_name=self.model,
            **{name: getattr(self, name) for name in hyperparam_names},
        )
        train_llm(params)
autotrain-advanced/src/autotrain/cli/run_setup.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ from argparse import ArgumentParser
3
+
4
+ from loguru import logger
5
+
6
+ from . import BaseAutoTrainCommand
7
+
8
+
9
def run_app_command_factory(args):
    """Build a :class:`RunSetupCommand` from parsed CLI arguments."""
    return RunSetupCommand(update_torch=args.update_torch)
11
+
12
+
13
class RunSetupCommand(BaseAutoTrainCommand):
    """CLI command that reinstalls bleeding-edge training dependencies.

    Installs transformers, peft, diffusers, and trl from their git main
    branches, and optionally PyTorch from the CUDA 11.8 wheel index.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``setup`` subcommand to *parser*."""
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool):
        # whether to also reinstall PyTorch from the CUDA 11.8 index
        self.update_torch = update_torch

    @staticmethod
    def _pip_run(cmd: str, installing_msg: str, installed_msg: str) -> None:
        """Run a pip shell command, logging progress and any failure.

        The original code never checked the return code, so failed
        installs went unnoticed; keep the best-effort behavior but emit a
        warning with the captured stderr when pip exits non-zero.
        """
        pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logger.info(installing_msg)
        _, err = pipe.communicate()
        if pipe.returncode != 0:
            logger.warning(f"pip command failed: {err.decode(errors='replace')}")
            return
        logger.info(installed_msg)

    def run(self):
        """Reinstall the training stack from source, then torch if requested."""
        self._pip_run(
            "pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.git",
            "Installing latest transformers@main",
            "Successfully installed latest transformers",
        )
        self._pip_run(
            "pip uninstall -y peft && pip install git+https://github.com/huggingface/peft.git",
            "Installing latest peft@main",
            "Successfully installed latest peft",
        )
        self._pip_run(
            "pip uninstall -y diffusers && pip install git+https://github.com/huggingface/diffusers.git",
            "Installing latest diffusers@main",
            "Successfully installed latest diffusers",
        )
        self._pip_run(
            "pip uninstall -y trl && pip install git+https://github.com/lvwerra/trl.git",
            "Installing latest trl@main",
            "Successfully installed latest trl",
        )
        if self.update_torch:
            self._pip_run(
                "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
                "Installing latest PyTorch",
                "Successfully installed latest PyTorch",
            )
autotrain-advanced/src/autotrain/config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from loguru import logger
5
+
6
+
7
# Base URL of the AutoTrain backend service (overridable via env var).
AUTOTRAIN_BACKEND_API = os.getenv("AUTOTRAIN_BACKEND_API", "https://api.autotrain.huggingface.co")

# Base URL of the Hugging Face Hub (overridable via env var).
HF_API = os.getenv("HF_API", "https://huggingface.co")


# Route all loguru output to stderr with a compact "> LEVEL message" format.
logger.configure(handlers=[{"sink": sys.stderr, "format": "> <level>{level:<7} {message}</level>"}])
autotrain-advanced/src/autotrain/dataset.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import zipfile
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import pandas as pd
8
+ from loguru import logger
9
+
10
+ from autotrain.preprocessor.dreambooth import DreamboothPreprocessor
11
+ from autotrain.preprocessor.tabular import (
12
+ TabularBinaryClassificationPreprocessor,
13
+ TabularMultiClassClassificationPreprocessor,
14
+ TabularSingleColumnRegressionPreprocessor,
15
+ )
16
+ from autotrain.preprocessor.text import (
17
+ LLMPreprocessor,
18
+ TextBinaryClassificationPreprocessor,
19
+ TextMultiClassClassificationPreprocessor,
20
+ TextSingleColumnRegressionPreprocessor,
21
+ )
22
+ from autotrain.preprocessor.vision import ImageClassificationPreprocessor
23
+
24
+
25
def remove_non_image_files(folder):
    """Delete every file under *folder* that is not a .jpg/.jpeg/.png image.

    The extension check is case-insensitive. Each removed file's path is
    printed, matching the original behavior.
    """
    # extensions are compared after lowercasing, so lowercase entries suffice
    allowed_extensions = {".jpg", ".jpeg", ".png"}

    # os.walk already traverses the whole tree recursively; the original
    # version additionally recursed into each subfolder by hand, re-walking
    # (and re-checking) every subtree once per nesting level for no benefit.
    for root, _dirs, files in os.walk(folder):
        for file in files:
            file_extension = os.path.splitext(file)[1]
            if file_extension.lower() not in allowed_extensions:
                file_path = os.path.join(root, file)
                os.remove(file_path)
                print(f"Removed file: {file_path}")
+ remove_non_image_files(os.path.join(root, subfolder))
44
+
45
+
46
@dataclass
class AutoTrainDreamboothDataset:
    """Wraps user-supplied concept images for a DreamBooth data upload."""

    concept_images: List[Any]
    concept_name: str
    token: str
    project_name: str
    username: str

    def __post_init__(self):
        # the task is fixed for this dataset flavor
        self.task = "dreambooth"
        logger.info(self.__str__())

    def __str__(self) -> str:
        return f"Dataset: {self.project_name} ({self.task})\n"

    @property
    def num_samples(self):
        # one sample per concept image
        return len(self.concept_images)

    def prepare(self):
        """Hand the images off to the DreamBooth preprocessor."""
        DreamboothPreprocessor(
            concept_images=self.concept_images,
            concept_name=self.concept_name,
            token=self.token,
            project_name=self.project_name,
            username=self.username,
        ).prepare()
+
76
+
77
@dataclass
class AutoTrainImageClassificationDataset:
    """Prepares a zipped image-classification dataset for AutoTrain.

    Attributes:
        train_data: path to a zip archive with the training images.
        token: Hugging Face token handed to the preprocessor.
        project_name: AutoTrain project name.
        username: Hub username.
        valid_data: optional path to a zip archive with validation images.
        percent_valid: fraction of training data held out when no explicit
            validation archive is given (defaults to 0.2).
    """

    train_data: str
    token: str
    project_name: str
    username: str
    valid_data: Optional[str] = None
    percent_valid: Optional[float] = None

    def __str__(self) -> str:
        info = f"Dataset: {self.project_name} ({self.task})\n"
        info += f"Train data: {self.train_data}\n"
        info += f"Valid data: {self.valid_data}\n"
        return info

    def __post_init__(self):
        self.task = "image_multi_class_classification"
        # exactly one of valid_data / percent_valid may be supplied
        if not self.valid_data and self.percent_valid is None:
            self.percent_valid = 0.2
        elif self.valid_data and self.percent_valid is not None:
            raise ValueError("You can only specify one of valid_data or percent_valid")
        elif self.valid_data:
            self.percent_valid = 0.0
        logger.info(self.__str__())

        self.num_files = self._count_files()

    @property
    def num_samples(self):
        return self.num_files

    def _count_files(self):
        """Count entries across the train (and optional valid) archives."""
        # context managers close the handles (the original leaked open
        # ZipFile objects)
        with zipfile.ZipFile(self.train_data, "r") as zip_ref:
            num_files = len(zip_ref.namelist())
        if self.valid_data:
            with zipfile.ZipFile(self.valid_data, "r") as zip_ref:
                num_files += len(zip_ref.namelist())
        return num_files

    def _extract(self, archive, cache_dir):
        """Extract *archive* into a fresh cache directory and clean it up."""
        import shutil  # local import: only needed for the __MACOSX cleanup

        target_dir = os.path.join(cache_dir, "autotrain", str(uuid.uuid4()))
        os.makedirs(target_dir, exist_ok=True)
        with zipfile.ZipFile(archive, "r") as zip_ref:
            zip_ref.extractall(target_dir)
        # drop the macOS resource-fork directory; shutil.rmtree is safer and
        # more portable than shelling out to `rm -rf`
        macosx_dir = os.path.join(target_dir, "__MACOSX")
        if os.path.exists(macosx_dir):
            shutil.rmtree(macosx_dir)
        remove_non_image_files(target_dir)
        return target_dir

    def prepare(self):
        """Unzip the archives and run the image-classification preprocessor."""
        cache_dir = os.environ.get("HF_HOME")
        if not cache_dir:
            cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

        train_dir = self._extract(self.train_data, cache_dir)
        valid_dir = self._extract(self.valid_data, cache_dir) if self.valid_data else None

        preprocessor = ImageClassificationPreprocessor(
            train_data=train_dir,
            valid_data=valid_dir,
            token=self.token,
            project_name=self.project_name,
            username=self.username,
        )
        preprocessor.prepare()
+
157
+
158
+ @dataclass
159
+ class AutoTrainDataset:
160
+ train_data: List[str]
161
+ task: str
162
+ token: str
163
+ project_name: str
164
+ username: str
165
+ column_mapping: Optional[Dict[str, str]] = None
166
+ valid_data: Optional[List[str]] = None
167
+ percent_valid: Optional[float] = None
168
+
169
+ def __str__(self) -> str:
170
+ info = f"Dataset: {self.project_name} ({self.task})\n"
171
+ info += f"Train data: {self.train_data}\n"
172
+ info += f"Valid data: {self.valid_data}\n"
173
+ info += f"Column mapping: {self.column_mapping}\n"
174
+ return info
175
+
176
+ def __post_init__(self):
177
+ if not self.valid_data and self.percent_valid is None:
178
+ self.percent_valid = 0.2
179
+ elif self.valid_data and self.percent_valid is not None:
180
+ raise ValueError("You can only specify one of valid_data or percent_valid")
181
+ elif self.valid_data:
182
+ self.percent_valid = 0.0
183
+
184
+ self.train_df, self.valid_df = self._preprocess_data()
185
+ logger.info(self.__str__())
186
+
187
+ def _preprocess_data(self):
188
+ train_df = []
189
+ for file in self.train_data:
190
+ if isinstance(file, pd.DataFrame):
191
+ train_df.append(file)
192
+ else:
193
+ train_df.append(pd.read_csv(file))
194
+ if len(train_df) > 1:
195
+ train_df = pd.concat(train_df)
196
+ else:
197
+ train_df = train_df[0]
198
+
199
+ valid_df = None
200
+ if len(self.valid_data) > 0:
201
+ valid_df = []
202
+ for file in self.valid_data:
203
+ if isinstance(file, pd.DataFrame):
204
+ valid_df.append(file)
205
+ else:
206
+ valid_df.append(pd.read_csv(file))
207
+ if len(valid_df) > 1:
208
+ valid_df = pd.concat(valid_df)
209
+ else:
210
+ valid_df = valid_df[0]
211
+ return train_df, valid_df
212
+
213
+ @property
214
+ def num_samples(self):
215
+ return len(self.train_df) + len(self.valid_df) if self.valid_df is not None else len(self.train_df)
216
+
217
+ def prepare(self):
218
+ if self.task == "text_binary_classification":
219
+ text_column = self.column_mapping["text"]
220
+ label_column = self.column_mapping["label"]
221
+ preprocessor = TextBinaryClassificationPreprocessor(
222
+ train_data=self.train_df,
223
+ text_column=text_column,
224
+ label_column=label_column,
225
+ username=self.username,
226
+ project_name=self.project_name,
227
+ valid_data=self.valid_df,
228
+ test_size=self.percent_valid,
229
+ token=self.token,
230
+ seed=42,
231
+ )
232
+ preprocessor.prepare()
233
+
234
+ elif self.task == "text_multi_class_classification":
235
+ text_column = self.column_mapping["text"]
236
+ label_column = self.column_mapping["label"]
237
+ preprocessor = TextMultiClassClassificationPreprocessor(
238
+ train_data=self.train_df,
239
+ text_column=text_column,
240
+ label_column=label_column,
241
+ username=self.username,
242
+ project_name=self.project_name,
243
+ valid_data=self.valid_df,
244
+ test_size=self.percent_valid,
245
+ token=self.token,
246
+ seed=42,
247
+ )
248
+ preprocessor.prepare()
249
+
250
+ elif self.task == "text_single_column_regression":
251
+ text_column = self.column_mapping["text"]
252
+ label_column = self.column_mapping["label"]
253
+ preprocessor = TextSingleColumnRegressionPreprocessor(
254
+ train_data=self.train_df,
255
+ text_column=text_column,
256
+ label_column=label_column,
257
+ username=self.username,
258
+ project_name=self.project_name,
259
+ valid_data=self.valid_df,
260
+ test_size=self.percent_valid,
261
+ token=self.token,
262
+ seed=42,
263
+ )
264
+ preprocessor.prepare()
265
+
266
+ elif self.task == "lm_training":
267
+ text_column = self.column_mapping.get("text", None)
268
+ if text_column is None:
269
+ prompt_column = self.column_mapping["prompt"]
270
+ response_column = self.column_mapping["response"]
271
+ else:
272
+ prompt_column = None
273
+ response_column = None
274
+ context_column = self.column_mapping.get("context", None)
275
+ prompt_start_column = self.column_mapping.get("prompt_start", None)
276
+ preprocessor = LLMPreprocessor(
277
+ train_data=self.train_df,
278
+ text_column=text_column,
279
+ prompt_column=prompt_column,
280
+ response_column=response_column,
281
+ context_column=context_column,
282
+ prompt_start_column=prompt_start_column,
283
+ username=self.username,
284
+ project_name=self.project_name,
285
+ valid_data=self.valid_df,
286
+ test_size=self.percent_valid,
287
+ token=self.token,
288
+ seed=42,
289
+ )
290
+ preprocessor.prepare()
291
+
292
+ elif self.task == "tabular_binary_classification":
293
+ id_column = self.column_mapping["id"]
294
+ label_column = self.column_mapping["label"]
295
+ if len(id_column.strip()) == 0:
296
+ id_column = None
297
+ preprocessor = TabularBinaryClassificationPreprocessor(
298
+ train_data=self.train_df,
299
+ id_column=id_column,
300
+ label_column=label_column,
301
+ username=self.username,
302
+ project_name=self.project_name,
303
+ valid_data=self.valid_df,
304
+ test_size=self.percent_valid,
305
+ token=self.token,
306
+ seed=42,
307
+ )
308
+ preprocessor.prepare()
309
+ elif self.task == "tabular_multi_class_classification":
310
+ id_column = self.column_mapping["id"]
311
+ label_column = self.column_mapping["label"]
312
+ if len(id_column.strip()) == 0:
313
+ id_column = None
314
+ preprocessor = TabularMultiClassClassificationPreprocessor(
315
+ train_data=self.train_df,
316
+ id_column=id_column,
317
+ label_column=label_column,
318
+ username=self.username,
319
+ project_name=self.project_name,
320
+ valid_data=self.valid_df,
321
+ test_size=self.percent_valid,
322
+ token=self.token,
323
+ seed=42,
324
+ )
325
+ preprocessor.prepare()
326
+ elif self.task == "tabular_single_column_regression":
327
+ id_column = self.column_mapping["id"]
328
+ label_column = self.column_mapping["label"]
329
+ if len(id_column.strip()) == 0:
330
+ id_column = None
331
+ preprocessor = TabularSingleColumnRegressionPreprocessor(
332
+ train_data=self.train_df,
333
+ id_column=id_column,
334
+ label_column=label_column,
335
+ username=self.username,
336
+ project_name=self.project_name,
337
+ valid_data=self.valid_df,
338
+ test_size=self.percent_valid,
339
+ token=self.token,
340
+ seed=42,
341
+ )
342
+ preprocessor.prepare()
343
+ else:
344
+ raise ValueError(f"Task {self.task} not supported")
autotrain-advanced/src/autotrain/dreambooth_app.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pty
3
+ import random
4
+ import shutil
5
+ import string
6
+ import subprocess
7
+
8
+ import gradio as gr
9
+ from huggingface_hub import HfApi, whoami
10
+
11
+
12
+ # ❯ autotrain dreambooth --help
13
+ # usage: autotrain <command> [<args>] dreambooth [-h] --model MODEL [--revision REVISION] [--tokenizer TOKENIZER] --image-path IMAGE_PATH
14
+ # [--class-image-path CLASS_IMAGE_PATH] --prompt PROMPT [--class-prompt CLASS_PROMPT]
15
+ # [--num-class-images NUM_CLASS_IMAGES] [--class-labels-conditioning CLASS_LABELS_CONDITIONING]
16
+ # [--prior-preservation] [--prior-loss-weight PRIOR_LOSS_WEIGHT] --output OUTPUT [--seed SEED]
17
+ # --resolution RESOLUTION [--center-crop] [--train-text-encoder] [--batch-size BATCH_SIZE]
18
+ # [--sample-batch-size SAMPLE_BATCH_SIZE] [--epochs EPOCHS] [--num-steps NUM_STEPS]
19
+ # [--checkpointing-steps CHECKPOINTING_STEPS] [--resume-from-checkpoint RESUME_FROM_CHECKPOINT]
20
+ # [--gradient-accumulation GRADIENT_ACCUMULATION] [--gradient-checkpointing] [--lr LR] [--scale-lr]
21
+ # [--scheduler SCHEDULER] [--warmup-steps WARMUP_STEPS] [--num-cycles NUM_CYCLES] [--lr-power LR_POWER]
22
+ # [--dataloader-num-workers DATALOADER_NUM_WORKERS] [--use-8bit-adam] [--adam-beta1 ADAM_BETA1]
23
+ # [--adam-beta2 ADAM_BETA2] [--adam-weight-decay ADAM_WEIGHT_DECAY] [--adam-epsilon ADAM_EPSILON]
24
+ # [--max-grad-norm MAX_GRAD_NORM] [--allow-tf32]
25
+ # [--prior-generation-precision PRIOR_GENERATION_PRECISION] [--local-rank LOCAL_RANK] [--xformers]
26
+ # [--pre-compute-text-embeddings] [--tokenizer-max-length TOKENIZER_MAX_LENGTH]
27
+ # [--text-encoder-use-attention-mask] [--rank RANK] [--xl] [--fp16] [--bf16] [--hub-token HUB_TOKEN]
28
+ # [--hub-model-id HUB_MODEL_ID] [--push-to-hub] [--validation-prompt VALIDATION_PROMPT]
29
+ # [--num-validation-images NUM_VALIDATION_IMAGES] [--validation-epochs VALIDATION_EPOCHS]
30
+ # [--checkpoints-total-limit CHECKPOINTS_TOTAL_LIMIT] [--validation-images VALIDATION_IMAGES]
31
+ # [--logging]
32
+
33
+ REPO_ID = os.environ.get("REPO_ID")
34
+ ALLOWED_FILE_TYPES = ["png", "jpg", "jpeg"]
35
+ MODELS = [
36
+ "stabilityai/stable-diffusion-xl-base-1.0",
37
+ "runwayml/stable-diffusion-v1-5",
38
+ "stabilityai/stable-diffusion-2-1",
39
+ "stabilityai/stable-diffusion-2-1-base",
40
+ ]
41
+ WELCOME_TEXT = """
42
+ Welcome to the AutoTrain DreamBooth! This app allows you to train a DreamBooth model using AutoTrain.
43
+ The app runs on HuggingFace Spaces. Your data is not stored anywhere.
44
+ The trained model (LoRA) will be pushed to your HuggingFace Hub account.
45
+
46
+ You need to use your HuggingFace Hub write [token](https://huggingface.co/settings/tokens) to push the model to your account.
47
+
48
+ NOTE: This space requires GPU to train. Please make sure you have GPU enabled in space settings.
49
+ Please make sure to shutdown / pause the space to avoid any additional charges.
50
+ """
51
+
52
+ STEPS = """
53
+ 1. [Duplicate](https://huggingface.co/spaces/autotrain-projects/dreambooth?duplicate=true) this space
54
+ 2. Upgrade the space to GPU
55
+ 3. Enter your HuggingFace Hub write token
56
+ 4. Upload images and adjust prompt (remember the prompt!)
57
+ 5. Click on Train and wait for the training to finish
58
+ 6. Go to your HuggingFace Hub account to find the trained model
59
+
60
+ NOTE: For any issues or feature requests, please open an issue [here](https://github.com/huggingface/autotrain-advanced/issues)
61
+ """
62
+
63
+
64
def _update_project_name():
    """Generate a random project name and toggle the train button.

    The button is disabled while the /tmp/training marker file exists
    (i.e. while a training run is in progress).
    """
    alphabet = string.ascii_lowercase + string.digits
    segments = ("".join(random.choices(alphabet, k=4)) for _ in range(3))
    random_project_name = "-".join(segments)
    training_in_progress = os.path.exists(os.path.join("/tmp", "training"))
    return [
        gr.Text.update(value=random_project_name, visible=True, interactive=True),
        gr.Button.update(interactive=not training_in_progress),
    ]
78
+
79
+
80
def run_command(cmd):
    """Run *cmd* in a pseudo-terminal, streaming its output to stdout.

    A pty (rather than a plain pipe) makes the child behave as if attached
    to a terminal, so e.g. progress bars render live.

    Args:
        cmd: Sequence of command arguments; items are str()-converted.

    Returns:
        The child's exit code (0 on success).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
            (The original silently ignored failures, so the caller reported
            success even when training crashed.)
    """
    cmd = [str(c) for c in cmd]
    print(f"Running command: {' '.join(cmd)}")
    master, slave = pty.openpty()
    try:
        proc = subprocess.Popen(cmd, stdout=slave, stderr=slave)
    except Exception:
        os.close(master)
        os.close(slave)
        raise
    # The child holds its own copy of the slave end; close ours so that
    # reads on the master terminate once the child exits.
    os.close(slave)
    try:
        while True:
            try:
                chunk = os.read(master, 1024)
            except OSError:
                # On Linux the master read raises EIO when the slave closes.
                break
            if not chunk:
                break
            print(chunk.decode(errors="replace"), end="")
    finally:
        # BUG FIX: the master fd was never closed (fd leak per invocation).
        os.close(master)
    # BUG FIX: reap the child and surface its exit status.
    returncode = proc.wait()
    if returncode != 0:
        raise subprocess.CalledProcessError(returncode, cmd)
    return returncode
95
+
96
+
97
def _run_training(
    hub_token,
    project_name,
    model,
    images,
    prompt,
    learning_rate,
    num_steps,
    batch_size,
    gradient_accumulation_steps,
    prior_preservation,
    scale_lr,
    use_8bit_adam,
    train_text_encoder,
    gradient_checkpointing,
    center_crop,
    prior_loss_weight,
    num_cycles,
    lr_power,
    adam_beta1,
    adam_beta2,
    adam_weight_decay,
    adam_epsilon,
    max_grad_norm,
    warmup_steps,
    scheduler,
    resolution,
    fp16,
):
    """Validate preconditions, then launch ``autotrain dreambooth`` with the UI values.

    Copies the uploaded images to /tmp/data, builds the CLI invocation, runs
    it, and pauses the space afterwards. The trained model is pushed to
    ``<user>/<project_name>`` on the Hub.

    Returns:
        A ``gr.Markdown`` update describing success or failure.
    """
    # Refuse to train on the public template space itself.
    if REPO_ID == "autotrain-projects/dreambooth":
        return gr.Markdown.update(
            value="❌ Please [duplicate](https://huggingface.co/spaces/autotrain-projects/dreambooth?duplicate=true) this space before training."
        )

    api = HfApi(token=hub_token)

    # /tmp/training acts as a crude lock: only one run per space at a time.
    if os.path.exists(os.path.join("/tmp", "training")):
        return gr.Markdown.update(value="❌ Another training job is already running in this space.")

    with open(os.path.join("/tmp", "training"), "w") as f:
        f.write("training")

    hub_model_id = whoami(token=hub_token)["name"] + "/" + str(project_name).strip()

    image_path = "/tmp/data"
    os.makedirs(image_path, exist_ok=True)
    output_dir = "/tmp/model"
    os.makedirs(output_dir, exist_ok=True)

    for image in images:
        shutil.copy(image.name, image_path)

    cmd = [
        "autotrain",
        "dreambooth",
        "--model",
        model,
        "--output",
        output_dir,
        "--image-path",
        image_path,
        "--prompt",
        prompt,
        "--resolution",
        # BUG FIX: the UI's resolution value was ignored — "1024" was
        # hard-coded here even though a resolution widget exists.
        resolution,
        "--batch-size",
        batch_size,
        "--num-steps",
        num_steps,
        "--gradient-accumulation",
        gradient_accumulation_steps,
        "--lr",
        learning_rate,
        "--scheduler",
        scheduler,
        "--warmup-steps",
        warmup_steps,
        "--num-cycles",
        num_cycles,
        "--lr-power",
        lr_power,
        "--adam-beta1",
        adam_beta1,
        "--adam-beta2",
        adam_beta2,
        "--adam-weight-decay",
        adam_weight_decay,
        "--adam-epsilon",
        adam_epsilon,
        "--max-grad-norm",
        max_grad_norm,
        "--prior-loss-weight",
        prior_loss_weight,
        "--push-to-hub",
        "--hub-token",
        hub_token,
        "--hub-model-id",
        hub_model_id,
    ]

    # Boolean toggles are passed as bare flags.
    if prior_preservation:
        cmd.append("--prior-preservation")
    if scale_lr:
        cmd.append("--scale-lr")
    if use_8bit_adam:
        cmd.append("--use-8bit-adam")
    if train_text_encoder:
        cmd.append("--train-text-encoder")
    if gradient_checkpointing:
        cmd.append("--gradient-checkpointing")
    if center_crop:
        cmd.append("--center-crop")
    if fp16:
        cmd.append("--fp16")

    try:
        run_command(cmd)
        # Release the lock and pause the (billed) space once training is done.
        os.remove(os.path.join("/tmp", "training"))
        if REPO_ID is not None:
            api.pause_space(repo_id=REPO_ID)
        return gr.Markdown.update(value=f"✅ Training finished! Model pushed to {hub_model_id}")
    except Exception as e:
        print(e)
        print("Error running command")
        # Always release the lock so a failed run does not wedge the space.
        os.remove(os.path.join("/tmp", "training"))
        return gr.Markdown.update(value="❌ Error running command. Please try again.")
225
+
226
+
227
def main():
    """Build and return the Gradio Blocks app for AutoTrain DreamBooth.

    Lays out the token input, model/image/prompt inputs, training
    hyperparameters (with an "Advanced Parameters" accordion), and wires the
    Train button to ``_run_training``. On page load, ``_update_project_name``
    seeds a random project name and enables/disables the button depending on
    whether a run is already in progress.
    """
    with gr.Blocks(theme="freddyaboulton/dracula_revamped") as demo:
        # Header and usage instructions.
        gr.Markdown("## 🤗 AutoTrain DreamBooth")
        gr.Markdown(WELCOME_TEXT)
        with gr.Accordion("Steps", open=False):
            gr.Markdown(STEPS)
        hub_token = gr.Textbox(
            label="Hub Token",
            value="",
            lines=1,
            max_lines=1,
            interactive=True,
            type="password",
        )

        with gr.Row():
            # Column 1: project identity and data.
            with gr.Column():
                project_name = gr.Textbox(
                    label="Project name",
                    value="",
                    lines=1,
                    max_lines=1,
                    interactive=True,
                )
                model = gr.Dropdown(
                    label="Model",
                    choices=MODELS,
                    value=MODELS[0],
                    visible=True,
                    interactive=True,
                    elem_id="model",
                    allow_custom_values=True,
                )
                images = gr.File(
                    label="Images",
                    file_types=ALLOWED_FILE_TYPES,
                    file_count="multiple",
                    visible=True,
                    interactive=True,
                )

            # Column 2: prompt and core numeric hyperparameters.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="photo of sks dog",
                    lines=1,
                )
                with gr.Row():
                    learning_rate = gr.Number(
                        label="Learning Rate",
                        value=1e-4,
                        visible=True,
                        interactive=True,
                        elem_id="learning_rate",
                    )
                    num_steps = gr.Number(
                        label="Number of Steps",
                        value=500,
                        visible=True,
                        interactive=True,
                        elem_id="num_steps",
                        precision=0,
                    )
                    batch_size = gr.Number(
                        label="Batch Size",
                        value=1,
                        visible=True,
                        interactive=True,
                        elem_id="batch_size",
                        precision=0,
                    )
                with gr.Row():
                    gradient_accumulation_steps = gr.Number(
                        label="Gradient Accumulation Steps",
                        value=4,
                        visible=True,
                        interactive=True,
                        elem_id="gradient_accumulation_steps",
                        precision=0,
                    )
                    resolution = gr.Number(
                        label="Resolution",
                        value=1024,
                        visible=True,
                        interactive=True,
                        elem_id="resolution",
                        precision=0,
                    )
                    scheduler = gr.Dropdown(
                        label="Scheduler",
                        choices=["cosine", "linear", "constant"],
                        value="constant",
                        visible=True,
                        interactive=True,
                        elem_id="scheduler",
                    )
            # Column 3: boolean training toggles.
            with gr.Column():
                with gr.Group():
                    fp16 = gr.Checkbox(
                        label="FP16",
                        value=True,
                        visible=True,
                        interactive=True,
                        elem_id="fp16",
                    )
                    prior_preservation = gr.Checkbox(
                        label="Prior Preservation",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="prior_preservation",
                    )
                    scale_lr = gr.Checkbox(
                        label="Scale LR",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="scale_lr",
                    )
                    use_8bit_adam = gr.Checkbox(
                        label="Use 8bit Adam",
                        value=True,
                        visible=True,
                        interactive=True,
                        elem_id="use_8bit_adam",
                    )
                    train_text_encoder = gr.Checkbox(
                        label="Train Text Encoder",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="train_text_encoder",
                    )
                    gradient_checkpointing = gr.Checkbox(
                        label="Gradient Checkpointing",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="gradient_checkpointing",
                    )
                    center_crop = gr.Checkbox(
                        label="Center Crop",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="center_crop",
                    )
        # Rarely-changed optimizer/scheduler knobs live behind an accordion.
        with gr.Accordion("Advanced Parameters", open=False):
            with gr.Row():
                prior_loss_weight = gr.Number(
                    label="Prior Loss Weight",
                    value=1.0,
                    visible=True,
                    interactive=True,
                    elem_id="prior_loss_weight",
                )
                num_cycles = gr.Number(
                    label="Num Cycles",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="num_cycles",
                    precision=0,
                )
                lr_power = gr.Number(
                    label="LR Power",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="lr_power",
                )

                adam_beta1 = gr.Number(
                    label="Adam Beta1",
                    value=0.9,
                    visible=True,
                    interactive=True,
                    elem_id="adam_beta1",
                )
                adam_beta2 = gr.Number(
                    label="Adam Beta2",
                    value=0.999,
                    visible=True,
                    interactive=True,
                    elem_id="adam_beta2",
                )
                adam_weight_decay = gr.Number(
                    label="Adam Weight Decay",
                    value=1e-2,
                    visible=True,
                    interactive=True,
                    elem_id="adam_weight_decay",
                )
                adam_epsilon = gr.Number(
                    label="Adam Epsilon",
                    value=1e-8,
                    visible=True,
                    interactive=True,
                    elem_id="adam_epsilon",
                )
                max_grad_norm = gr.Number(
                    label="Max Grad Norm",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="max_grad_norm",
                )
                warmup_steps = gr.Number(
                    label="Warmup Steps",
                    value=0,
                    visible=True,
                    interactive=True,
                    elem_id="warmup_steps",
                    precision=0,
                )

        train_button = gr.Button(value="Train", elem_id="train")
        output_md = gr.Markdown("## Output")
        # Order must match _run_training's parameter order exactly.
        inputs = [
            hub_token,
            project_name,
            model,
            images,
            prompt,
            learning_rate,
            num_steps,
            batch_size,
            gradient_accumulation_steps,
            prior_preservation,
            scale_lr,
            use_8bit_adam,
            train_text_encoder,
            gradient_checkpointing,
            center_crop,
            prior_loss_weight,
            num_cycles,
            lr_power,
            adam_beta1,
            adam_beta2,
            adam_weight_decay,
            adam_epsilon,
            max_grad_norm,
            warmup_steps,
            scheduler,
            resolution,
            fp16,
        ]

        train_button.click(_run_training, inputs=inputs, outputs=output_md)
        # Runs once when the page loads: seeds a project name and sets the
        # button state based on the /tmp/training lock file.
        demo.load(
            _update_project_name,
            outputs=[project_name, train_button],
        )
    return demo
481
+
482
+
483
if __name__ == "__main__":
    # Build the Blocks app and start the Gradio server.
    main().launch()
autotrain-advanced/src/autotrain/help.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Help text shown in the AutoTrain UI next to the user/org selector.
APP_AUTOTRAIN_USERNAME = """Please choose the user or organization who is creating the AutoTrain Project.
In case of non-free tier, this user or organization will be billed.
"""

# Help text for the project-name input field.
APP_PROJECT_NAME = """A unique name for the AutoTrain Project.
This name will be used to identify the project in the AutoTrain dashboard."""


# Describes the expected zip layout for image-classification uploads.
APP_IMAGE_CLASSIFICATION_DATA_HELP = """The data for the Image Classification task should be in the following format:
- The data should be in a zip file.
- The zip file should contain multiple folders (the classes), each folder should contain images of a single class.
- The name of the folder should be the name of the class.
- The images must be jpeg, jpg or png.
- There should be at least 5 images per class.
- There should not be any other files in the zip file.
- There should not be any other folders inside the zip folder.
"""

# Explains the two LM training modes (generic vs chat) and their CSV columns.
APP_LM_TRAINING_TYPE = """There are two types of Language Model Training:
- generic
- chat

In the generic mode, you provide a CSV with a text column which has already been formatted by you for training a language model.
In the chat mode, you provide a CSV with two or three text columns: prompt, context (optional) and response.
Context column can be empty for samples if not needed. You can also have a "prompt start" column. If provided, "prompt start" will be prepended before the prompt column.

Please see [this](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset which has both formats in the same dataset.
"""
autotrain-advanced/src/autotrain/infer/__init__.py ADDED
File without changes
autotrain-advanced/src/autotrain/infer/text_generation.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6
+
7
+
8
@dataclass
class TextGenerationInference:
    """Inference wrapper: loads a causal LM + tokenizer and generates text.

    The model and tokenizer are loaded eagerly in ``__post_init__``, and the
    sampling settings are frozen into a ``GenerationConfig`` at construction.
    """

    model_path: str = "gpt2"  # local path or Hub repo id
    use_int4: Optional[bool] = False  # load weights quantized to 4-bit
    use_int8: Optional[bool] = False  # load weights quantized to 8-bit
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.95
    repetition_penalty: Optional[float] = 1.0
    num_return_sequences: Optional[int] = 1
    num_beams: Optional[int] = 1
    max_new_tokens: Optional[int] = 1024
    do_sample: Optional[bool] = True

    def __post_init__(self):
        # NOTE(review): torch_dtype=float16 is requested unconditionally,
        # even when only CPU is available — confirm this is intended.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_4bit=self.use_int4,
            load_in_8bit=self.use_int8,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model.eval()
        # chat() moves inputs to this device; assumes device_map="auto"
        # placed the model on a single device — TODO confirm for sharded models.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generation_config = GenerationConfig(
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            num_return_sequences=self.num_return_sequences,
            num_beams=self.num_beams,
            # NOTE(review): both max_length and max_new_tokens are set from
            # the same field, but they have different semantics (total length
            # vs newly generated tokens) — confirm which limit is intended.
            max_length=self.max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=self.do_sample,
            max_new_tokens=self.max_new_tokens,
        )

    def chat(self, prompt):
        """Generate a completion for ``prompt`` and return the decoded text.

        Only special tokens are stripped, so for decoder-only models the
        returned string presumably still contains the prompt — TODO confirm.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
autotrain-advanced/src/autotrain/languages.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Language codes selectable in the AutoTrain UI (ISO 639-1 style);
# "unk" is the sentinel for an unknown / unlisted language.
SUPPORTED_LANGUAGES = [
    "en",
    "ar",
    "bn",
    "de",
    "es",
    "fi",
    "fr",
    "hi",
    "it",
    "ja",
    "ko",
    "nl",
    "pt",
    "sv",
    "tr",
    "zh",
    "unk",
]
autotrain-advanced/src/autotrain/params.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+ import gradio as gr
5
+ from pydantic import BaseModel, Field
6
+
7
+ from autotrain.languages import SUPPORTED_LANGUAGES
8
+ from autotrain.tasks import TASKS
9
+
10
+
11
+ class LoraR:
12
+ TYPE = "int"
13
+ MIN_VALUE = 1
14
+ MAX_VALUE = 100
15
+ DEFAULT = 16
16
+ STEP = 1
17
+ STREAMLIT_INPUT = "number_input"
18
+ PRETTY_NAME = "LoRA R"
19
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
20
+
21
+
22
+ class LoraAlpha:
23
+ TYPE = "int"
24
+ MIN_VALUE = 1
25
+ MAX_VALUE = 256
26
+ DEFAULT = 32
27
+ STEP = 1
28
+ STREAMLIT_INPUT = "number_input"
29
+ PRETTY_NAME = "LoRA Alpha"
30
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
31
+
32
+
33
+ class LoraDropout:
34
+ TYPE = "float"
35
+ MIN_VALUE = 0.0
36
+ MAX_VALUE = 1.0
37
+ DEFAULT = 0.05
38
+ STEP = 0.01
39
+ STREAMLIT_INPUT = "number_input"
40
+ PRETTY_NAME = "LoRA Dropout"
41
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
42
+
43
+
44
+ class LearningRate:
45
+ TYPE = "float"
46
+ MIN_VALUE = 1e-7
47
+ MAX_VALUE = 1e-1
48
+ DEFAULT = 1e-3
49
+ STEP = 1e-6
50
+ FORMAT = "%.2E"
51
+ STREAMLIT_INPUT = "number_input"
52
+ PRETTY_NAME = "Learning Rate"
53
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
54
+
55
+
56
+ class LMLearningRate(LearningRate):
57
+ DEFAULT = 5e-5
58
+
59
+
60
+ class Optimizer:
61
+ TYPE = "str"
62
+ DEFAULT = "adamw_torch"
63
+ CHOICES = ["adamw_torch", "adamw_hf", "sgd", "adafactor", "adagrad"]
64
+ STREAMLIT_INPUT = "selectbox"
65
+ PRETTY_NAME = "Optimizer"
66
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
67
+
68
+
69
class LMTrainingType:
    """UI/parameter spec for the LM training mode ("generic" or "chat")."""

    TYPE = "str"
    DEFAULT = "generic"
    CHOICES = ["generic", "chat"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "LM Training Type"
    # BUG FIX: this attribute was misspelled GRAIDO_INPUT, so any generic code
    # reading GRADIO_INPUT (as every sibling spec class defines it) would fail
    # with AttributeError for this parameter.
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
    # Backward-compatible alias for any caller still using the old misspelling.
    GRAIDO_INPUT = GRADIO_INPUT
76
+
77
+
78
+ class Scheduler:
79
+ TYPE = "str"
80
+ DEFAULT = "linear"
81
+ CHOICES = ["linear", "cosine"]
82
+ STREAMLIT_INPUT = "selectbox"
83
+ PRETTY_NAME = "Scheduler"
84
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
85
+
86
+
87
+ class TrainBatchSize:
88
+ TYPE = "int"
89
+ MIN_VALUE = 1
90
+ MAX_VALUE = 128
91
+ DEFAULT = 2
92
+ STEP = 2
93
+ STREAMLIT_INPUT = "number_input"
94
+ PRETTY_NAME = "Train Batch Size"
95
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
96
+
97
+
98
+ class LMTrainBatchSize(TrainBatchSize):
99
+ DEFAULT = 4
100
+
101
+
102
+ class Epochs:
103
+ TYPE = "int"
104
+ MIN_VALUE = 1
105
+ MAX_VALUE = 1000
106
+ DEFAULT = 10
107
+ STREAMLIT_INPUT = "number_input"
108
+ PRETTY_NAME = "Epochs"
109
+ GRADIO_INPUT = gr.Number(value=DEFAULT)
110
+
111
+
112
+ class LMEpochs(Epochs):
113
+ DEFAULT = 1
114
+
115
+
116
+ class PercentageWarmup:
117
+ TYPE = "float"
118
+ MIN_VALUE = 0.0
119
+ MAX_VALUE = 1.0
120
+ DEFAULT = 0.1
121
+ STEP = 0.01
122
+ STREAMLIT_INPUT = "number_input"
123
+ PRETTY_NAME = "Percentage Warmup"
124
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
125
+
126
+
127
+ class GradientAccumulationSteps:
128
+ TYPE = "int"
129
+ MIN_VALUE = 1
130
+ MAX_VALUE = 100
131
+ DEFAULT = 1
132
+ STREAMLIT_INPUT = "number_input"
133
+ PRETTY_NAME = "Gradient Accumulation Steps"
134
+ GRADIO_INPUT = gr.Number(value=DEFAULT)
135
+
136
+
137
+ class WeightDecay:
138
+ TYPE = "float"
139
+ MIN_VALUE = 0.0
140
+ MAX_VALUE = 1.0
141
+ DEFAULT = 0.0
142
+ STREAMLIT_INPUT = "number_input"
143
+ PRETTY_NAME = "Weight Decay"
144
+ GRADIO_INPUT = gr.Number(value=DEFAULT)
145
+
146
+
147
+ class SourceLanguage:
148
+ TYPE = "str"
149
+ DEFAULT = "en"
150
+ CHOICES = SUPPORTED_LANGUAGES
151
+ STREAMLIT_INPUT = "selectbox"
152
+ PRETTY_NAME = "Source Language"
153
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
154
+
155
+
156
+ class TargetLanguage:
157
+ TYPE = "str"
158
+ DEFAULT = "en"
159
+ CHOICES = SUPPORTED_LANGUAGES
160
+ STREAMLIT_INPUT = "selectbox"
161
+ PRETTY_NAME = "Target Language"
162
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
163
+
164
+
165
+ class NumModels:
166
+ TYPE = "int"
167
+ MIN_VALUE = 1
168
+ MAX_VALUE = 25
169
+ DEFAULT = 1
170
+ STREAMLIT_INPUT = "number_input"
171
+ PRETTY_NAME = "Number of Models"
172
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=1)
173
+
174
+
175
+ class DBNumSteps:
176
+ TYPE = "int"
177
+ MIN_VALUE = 100
178
+ MAX_VALUE = 10000
179
+ DEFAULT = 1500
180
+ STREAMLIT_INPUT = "number_input"
181
+ PRETTY_NAME = "Number of Steps"
182
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=100)
183
+
184
+
185
+ class DBTextEncoderStepsPercentage:
186
+ TYPE = "int"
187
+ MIN_VALUE = 1
188
+ MAX_VALUE = 100
189
+ DEFAULT = 30
190
+ STREAMLIT_INPUT = "number_input"
191
+ PRETTY_NAME = "Text encoder steps percentage"
192
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=1)
193
+
194
+
195
class DBPriorPreservation:
    # UI/parameter spec for the DreamBooth prior-preservation toggle.
    TYPE = "bool"
    DEFAULT = False
    STREAMLIT_INPUT = "checkbox"
    PRETTY_NAME = "Prior preservation"
    # NOTE(review): the Gradio widget yields the *strings* "True"/"False"
    # while TYPE/DEFAULT are booleans — confirm downstream code coerces the
    # selected value before use.
    GRADIO_INPUT = gr.Dropdown(["True", "False"], value="False")
201
+
202
+
203
+ class ImageSize:
204
+ TYPE = "int"
205
+ MIN_VALUE = 64
206
+ MAX_VALUE = 2048
207
+ DEFAULT = 512
208
+ STREAMLIT_INPUT = "number_input"
209
+ PRETTY_NAME = "Image Size"
210
+ GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=64)
211
+
212
+
213
+ class DreamboothConceptType:
214
+ TYPE = "str"
215
+ DEFAULT = "person"
216
+ CHOICES = ["person", "object"]
217
+ STREAMLIT_INPUT = "selectbox"
218
+ PRETTY_NAME = "Concept Type"
219
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
220
+
221
+
222
+ class SourceLanguageUnk:
223
+ TYPE = "str"
224
+ DEFAULT = "unk"
225
+ CHOICES = ["unk"]
226
+ STREAMLIT_INPUT = "selectbox"
227
+ PRETTY_NAME = "Source Language"
228
+ GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
229
+
230
+
231
+ class HubModel:
232
+ TYPE = "str"
233
+ DEFAULT = "bert-base-uncased"
234
+ PRETTY_NAME = "Hub Model"
235
+ GRADIO_INPUT = gr.Textbox(lines=1, max_lines=1, label="Hub Model")
236
+
237
+
238
class TextBinaryClassificationParams(BaseModel):
    """Training hyperparameters for binary text classification."""

    task: Literal["text_binary_classification"]  # discriminator for the task union
    learning_rate: float = Field(5e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    max_seq_length: int = Field(128, title="Max sequence length")
    train_batch_size: int = Field(32, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
251
+
252
+
253
class TextMultiClassClassificationParams(BaseModel):
    """Training hyperparameters for multi-class text classification."""

    task: Literal["text_multi_class_classification"]  # discriminator for the task union
    learning_rate: float = Field(5e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    max_seq_length: int = Field(128, title="Max sequence length")
    train_batch_size: int = Field(32, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
266
+
267
+
268
class DreamboothParams(BaseModel):
    """Manual training hyperparameters for Dreambooth fine-tuning.

    NOTE(review): ``text_encoder_steps_percentage`` is presumably the share
    of ``num_steps`` during which the text encoder is also trained — confirm
    against the Dreambooth trainer.
    """

    task: Literal["dreambooth"]
    num_steps: int = Field(1500, title="Number of steps")
    image_size: int = Field(512, title="Image size")
    text_encoder_steps_percentage: int = Field(30, title="Text encoder steps percentage")
    prior_preservation: bool = Field(False, title="Prior preservation")
    learning_rate: float = Field(2e-6, title="Learning rate")
    train_batch_size: int = Field(1, title="Training batch size")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
277
+
278
+
279
class ImageBinaryClassificationParams(BaseModel):
    """Manual training hyperparameters for binary image classification.

    Mirrors the text-classification schemas but with image-appropriate
    defaults (smaller batch, 3e-5 LR) and no ``max_seq_length``.
    """

    task: Literal["image_binary_classification"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
291
+
292
+
293
class ImageMultiClassClassificationParams(BaseModel):
    """Manual training hyperparameters for multi-class image classification.

    Field-for-field identical to ``ImageBinaryClassificationParams`` except
    for the ``task`` discriminator literal.
    """

    task: Literal["image_multi_class_classification"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
305
+
306
+
307
class LMTrainingParams(BaseModel):
    """Manual hyperparameters for causal language-model fine-tuning.

    Includes LoRA knobs (``lora_r``/``lora_alpha``/``lora_dropout``).
    NOTE(review): ``block_size=-1`` presumably means "use the model default
    context length" — confirm against the LM trainer.
    """

    task: Literal["lm_training"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
    add_eos_token: bool = Field(True, title="Add EOS token")
    block_size: int = Field(-1, title="Block size")
    lora_r: int = Field(16, title="Lora r")
    lora_alpha: int = Field(32, title="Lora alpha")
    lora_dropout: float = Field(0.05, title="Lora dropout")
    training_type: str = Field("generic", title="Training type")
    train_on_inputs: bool = Field(False, title="Train on inputs")
326
+
327
+
328
@dataclass
class Params:
    """Resolve the UI parameter widgets for a (task, param_choice, model_choice) combo.

    Each ``_<task>`` helper returns a mapping of parameter name -> widget
    metadata class (e.g. ``HubModel``, ``LearningRate``); :meth:`get`
    dispatches on ``self.task``.

    Attributes:
        task: Task identifier; must be a key of the module-level ``TASKS`` mapping.
        param_choice: "autotrain" (automatic tuning) or "manual" (user-set values).
        model_choice: "autotrain" (AutoTrain-selected) or "hub_model" (user-supplied).
    """

    task: str
    param_choice: str
    model_choice: str

    def __post_init__(self):
        """Validate the three choices and cache the numeric task id."""
        # task should be one of the keys in TASKS
        if self.task not in TASKS:
            raise ValueError(f"task must be one of {TASKS.keys()}")
        self.task_id = TASKS[self.task]

        if self.param_choice not in ("autotrain", "manual"):
            raise ValueError("param_choice must be either autotrain or manual")

        if self.model_choice not in ("autotrain", "hub_model"):
            raise ValueError("model_choice must be either autotrain or hub_model")

    def _dreambooth(self):
        """Parameters for Dreambooth fine-tuning."""
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "image_size": ImageSize,
                "learning_rate": LearningRate,
                "train_batch_size": TrainBatchSize,
                "num_steps": DBNumSteps,
                "gradient_accumulation_steps": GradientAccumulationSteps,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "hub_model":
                return {
                    "hub_model": HubModel,
                    "image_size": ImageSize,
                    "num_models": NumModels,
                }
            return {
                "num_models": NumModels,
            }
        # Unreachable after __post_init__ validation; guard added for
        # consistency with the sibling helpers (previously fell through
        # to an implicit ``return None``).
        raise ValueError("param_choice must be either autotrain or manual")

    def _tabular_binary_classification(self):
        """Tabular tasks expose only the number of models to try."""
        return {
            "num_models": NumModels,
        }

    def _lm_training(self):
        """Parameters for causal LM fine-tuning (incl. LoRA knobs in manual mode)."""
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LMLearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": LMTrainBatchSize,
                "num_train_epochs": LMEpochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
                "lora_r": LoraR,
                "lora_alpha": LoraAlpha,
                "lora_dropout": LoraDropout,
                "training_type": LMTrainingType,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "num_models": NumModels,
                    "training_type": LMTrainingType,
                }
            return {
                "hub_model": HubModel,
                "num_models": NumModels,
                "training_type": LMTrainingType,
            }
        raise ValueError("param_choice must be either autotrain or manual")

    def _tabular_multi_class_classification(self):
        return self._tabular_binary_classification()

    def _tabular_single_column_regression(self):
        return self._tabular_binary_classification()

    def _tabular_multi_label_classification(self):
        return self._tabular_binary_classification()

    # Backward-compatible public alias: this helper was historically exposed
    # without the leading underscore, unlike its sibling task helpers.
    tabular_multi_label_classification = _tabular_multi_label_classification

    def _text_binary_classification(self):
        """Parameters for text classification (binary and multi-class share these)."""
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": TrainBatchSize,
                "num_train_epochs": Epochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "source_language": SourceLanguage,
                    "num_models": NumModels,
                }
            return {
                "hub_model": HubModel,
                "source_language": SourceLanguageUnk,
                "num_models": NumModels,
            }
        raise ValueError("param_choice must be either autotrain or manual")

    def _text_entity_extraction(self):
        return self._text_binary_classification()

    def _text_single_column_regression(self):
        return self._text_binary_classification()

    def _text_natural_language_inference(self):
        return self._text_binary_classification()

    def _image_binary_classification(self):
        """Parameters for image classification (binary and multi-class share these)."""
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": TrainBatchSize,
                "num_train_epochs": Epochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "num_models": NumModels,
                }
            return {
                "hub_model": HubModel,
                "num_models": NumModels,
            }
        raise ValueError("param_choice must be either autotrain or manual")

    def _image_multi_class_classification(self):
        return self._image_binary_classification()

    def get(self):
        """Return the widget mapping for ``self.task``; raises ValueError if unsupported."""
        if self.task in ("text_binary_classification", "text_multi_class_classification"):
            return self._text_binary_classification()

        if self.task == "text_entity_extraction":
            return self._text_entity_extraction()

        if self.task == "text_single_column_regression":
            return self._text_single_column_regression()

        if self.task == "text_natural_language_inference":
            return self._text_natural_language_inference()

        if self.task == "tabular_binary_classification":
            return self._tabular_binary_classification()

        if self.task == "tabular_multi_class_classification":
            return self._tabular_multi_class_classification()

        if self.task == "tabular_single_column_regression":
            return self._tabular_single_column_regression()

        if self.task == "tabular_multi_label_classification":
            return self._tabular_multi_label_classification()

        if self.task in ("image_binary_classification", "image_multi_class_classification"):
            return self._image_binary_classification()

        if self.task == "dreambooth":
            return self._dreambooth()

        if self.task == "lm_training":
            return self._lm_training()

        raise ValueError(f"task {self.task} not supported")
autotrain-advanced/src/autotrain/preprocessor/__init__.py ADDED
File without changes
autotrain-advanced/src/autotrain/preprocessor/dreambooth.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import json
3
+ from dataclasses import dataclass
4
+ from typing import Any, List
5
+
6
+ from huggingface_hub import HfApi, create_repo
7
+ from loguru import logger
8
+
9
+
10
@dataclass
class DreamboothPreprocessor:
    """Upload Dreambooth concept images and prompts to a private HF dataset repo.

    Attributes:
        concept_images: Uploaded file objects; each must expose a ``.name`` path.
        concept_name: Prompt text associated with the concept.
        username: Hub username that owns the dataset repo.
        project_name: AutoTrain project name, used to derive the repo id.
        token: Hub API token with write access.

    Raises:
        ValueError: If the dataset repo cannot be created (including when it
            already exists, since ``exist_ok=False``).
    """

    concept_images: List[Any]
    concept_name: str
    username: str
    project_name: str
    token: str

    def __post_init__(self):
        self.repo_name = f"{self.username}/autotrain-data-{self.project_name}"
        try:
            create_repo(
                repo_id=self.repo_name,
                repo_type="dataset",
                token=self.token,
                private=True,
                exist_ok=False,  # a pre-existing repo is treated as a failure
            )
        except Exception as err:
            # Fix: log the underlying error and chain it onto the ValueError
            # instead of silently discarding the cause.
            logger.error(f"Error creating repo: {err}")
            raise ValueError("Error creating repo") from err

    def _upload_concept_images(self, file, api):
        """Upload a single concept image into the repo's ``concept1/`` folder."""
        logger.info(f"Uploading {file} to concept1")
        api.upload_file(
            path_or_fileobj=file.name,
            path_in_repo=f"concept1/{file.name.split('/')[-1]}",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def _upload_concept_prompts(self, api):
        """Serialize the concept -> prompt mapping and upload it as prompts.json."""
        _prompts = {}
        _prompts["concept1"] = self.concept_name

        prompts = json.dumps(_prompts)
        prompts = prompts.encode("utf-8")
        prompts = io.BytesIO(prompts)  # upload_file accepts an in-memory buffer
        api.upload_file(
            path_or_fileobj=prompts,
            path_in_repo="prompts.json",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def prepare(self):
        """Upload all concept images, then the prompts file (entry point)."""
        api = HfApi()
        for _file in self.concept_images:
            self._upload_concept_images(_file, api)

        self._upload_concept_prompts(api)
autotrain-advanced/src/autotrain/preprocessor/tabular.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import pandas as pd
5
+ from datasets import Dataset
6
+ from sklearn.model_selection import train_test_split
7
+
8
+
9
# Column names injected by the preprocessors; user data must not already use them.
RESERVED_COLUMNS = ["autotrain_id", "autotrain_label"]
10
+
11
+
12
@dataclass
class TabularBinaryClassificationPreprocessor:
    """Validate, split, and publish tabular binary-classification data.

    Maps the user's id/label columns onto the reserved ``autotrain_id`` /
    ``autotrain_label`` columns and pushes train/validation splits to a
    private dataset repo on the Hugging Face Hub.
    """

    train_data: pd.DataFrame  # full training frame (auto-split when no valid_data)
    label_column: str  # name of the target column
    username: str  # Hub username that owns the dataset repo
    project_name: str  # used to derive the dataset repo id
    id_column: Optional[str] = None  # optional row-id column; row index used when absent
    valid_data: Optional[pd.DataFrame] = None  # pre-made validation frame, if any
    test_size: Optional[float] = 0.2  # validation fraction when auto-splitting
    seed: Optional[int] = 42  # RNG seed for the auto-split

    def __post_init__(self):
        """Validate column presence and guard against reserved column names."""
        # check if id_column and label_column are in train_data
        if self.id_column is not None:
            if self.id_column not in self.train_data.columns:
                raise ValueError(f"{self.id_column} not in train data")

        if self.label_column not in self.train_data.columns:
            raise ValueError(f"{self.label_column} not in train data")

        # check if id_column and label_column are in valid_data
        if self.valid_data is not None:
            if self.id_column is not None:
                if self.id_column not in self.valid_data.columns:
                    raise ValueError(f"{self.id_column} not in valid data")
            if self.label_column not in self.valid_data.columns:
                raise ValueError(f"{self.label_column} not in valid data")

        # make sure no reserved columns are in train_data or valid_data
        for column in RESERVED_COLUMNS:
            if column in self.train_data.columns:
                raise ValueError(f"{column} is a reserved column name")
            if self.valid_data is not None:
                if column in self.valid_data.columns:
                    raise ValueError(f"{column} is a reserved column name")

    def split(self):
        """Return (train_df, valid_df), stratifying on the label when auto-splitting."""
        if self.valid_data is not None:
            return self.train_data, self.valid_data
        else:
            train_df, valid_df = train_test_split(
                self.train_data,
                test_size=self.test_size,
                random_state=self.seed,
                stratify=self.train_data[self.label_column],  # keep class balance
            )
            train_df = train_df.reset_index(drop=True)
            valid_df = valid_df.reset_index(drop=True)
            return train_df, valid_df

    def prepare_columns(self, train_df, valid_df):
        """Map user columns onto the reserved schema and drop the originals.

        Fix: the original indexed ``df[self.id_column]`` unconditionally,
        which raised ``KeyError: None`` under the default ``id_column=None``;
        the positional row index is now used as the id in that case.
        """
        drop_cols = [self.label_column]
        if self.id_column is not None:
            train_df.loc[:, "autotrain_id"] = train_df[self.id_column]
            valid_df.loc[:, "autotrain_id"] = valid_df[self.id_column]
            drop_cols.append(self.id_column)
        else:
            # no id column supplied: fall back to the row index
            train_df.loc[:, "autotrain_id"] = train_df.index
            valid_df.loc[:, "autotrain_id"] = valid_df.index
        train_df.loc[:, "autotrain_label"] = train_df[self.label_column]
        valid_df.loc[:, "autotrain_label"] = valid_df[self.label_column]

        # drop id_column (when present) and label_column
        train_df = train_df.drop(columns=drop_cols)
        valid_df = valid_df.drop(columns=drop_cols)
        return train_df, valid_df

    def prepare(self):
        """Split, reshape, and push both splits to the Hub; returns the Datasets."""
        train_df, valid_df = self.split()
        train_df, valid_df = self.prepare_columns(train_df, valid_df)
        train_df = Dataset.from_pandas(train_df)
        valid_df = Dataset.from_pandas(valid_df)
        train_df.push_to_hub(f"{self.username}/autotrain-data-{self.project_name}", split="train", private=True)
        valid_df.push_to_hub(f"{self.username}/autotrain-data-{self.project_name}", split="validation", private=True)
        return train_df, valid_df
81
+
82
+
83
class TabularMultiClassClassificationPreprocessor(TabularBinaryClassificationPreprocessor):
    """Identical to the binary preprocessor: the stratified split and
    reserved-column handling apply unchanged to multi-class labels."""

    pass
85
+
86
+
87
class TabularSingleColumnRegressionPreprocessor(TabularBinaryClassificationPreprocessor):
    """Preprocessor for single-column regression targets.

    Overrides :meth:`split` only: regression targets are continuous, so the
    train/validation split is done without stratification.
    """

    def split(self):
        """Return (train_df, valid_df), auto-splitting when no validation frame is given."""
        if self.valid_data is None:
            # No stratify= here, unlike the classification base class.
            train_part, valid_part = train_test_split(
                self.train_data,
                test_size=self.test_size,
                random_state=self.seed,
            )
            return (
                train_part.reset_index(drop=True),
                valid_part.reset_index(drop=True),
            )
        return self.train_data, self.valid_data
+ return train_df, valid_df