Commit ·
c0551d3
1
Parent(s): 31a0704
End of training
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +15 -0
- autotrain-advanced/.dockerignore +9 -0
- autotrain-advanced/.github/workflows/build_documentation.yml +19 -0
- autotrain-advanced/.github/workflows/build_pr_documentation.yml +17 -0
- autotrain-advanced/.github/workflows/code_quality.yml +30 -0
- autotrain-advanced/.github/workflows/delete_doc_comment.yml +13 -0
- autotrain-advanced/.github/workflows/delete_doc_comment_trigger.yml +12 -0
- autotrain-advanced/.github/workflows/tests.yml +30 -0
- autotrain-advanced/.github/workflows/upload_pr_documentation.yml +16 -0
- autotrain-advanced/.gitignore +138 -0
- autotrain-advanced/Dockerfile +65 -0
- autotrain-advanced/LICENSE +202 -0
- autotrain-advanced/Makefile +28 -0
- autotrain-advanced/README.md +13 -0
- autotrain-advanced/docs/source/_toctree.yml +28 -0
- autotrain-advanced/docs/source/cost.mdx +17 -0
- autotrain-advanced/docs/source/dreambooth.mdx +18 -0
- autotrain-advanced/docs/source/getting_started.mdx +29 -0
- autotrain-advanced/docs/source/image_classification.mdx +40 -0
- autotrain-advanced/docs/source/index.mdx +34 -0
- autotrain-advanced/docs/source/llm_finetuning.mdx +43 -0
- autotrain-advanced/docs/source/model_choice.mdx +24 -0
- autotrain-advanced/docs/source/param_choice.mdx +25 -0
- autotrain-advanced/docs/source/support.mdx +12 -0
- autotrain-advanced/docs/source/text_classification.mdx +60 -0
- autotrain-advanced/examples/text_classification_binary.py +77 -0
- autotrain-advanced/examples/text_classification_multiclass.py +77 -0
- autotrain-advanced/requirements.txt +31 -0
- autotrain-advanced/setup.cfg +24 -0
- autotrain-advanced/setup.py +71 -0
- autotrain-advanced/src/autotrain/__init__.py +24 -0
- autotrain-advanced/src/autotrain/app.py +965 -0
- autotrain-advanced/src/autotrain/cli/__init__.py +13 -0
- autotrain-advanced/src/autotrain/cli/accelerated_autotrain.py +0 -0
- autotrain-advanced/src/autotrain/cli/autotrain.py +40 -0
- autotrain-advanced/src/autotrain/cli/run_app.py +55 -0
- autotrain-advanced/src/autotrain/cli/run_dreambooth.py +469 -0
- autotrain-advanced/src/autotrain/cli/run_llm.py +489 -0
- autotrain-advanced/src/autotrain/cli/run_setup.py +61 -0
- autotrain-advanced/src/autotrain/config.py +12 -0
- autotrain-advanced/src/autotrain/dataset.py +344 -0
- autotrain-advanced/src/autotrain/dreambooth_app.py +485 -0
- autotrain-advanced/src/autotrain/help.py +28 -0
- autotrain-advanced/src/autotrain/infer/__init__.py +0 -0
- autotrain-advanced/src/autotrain/infer/text_generation.py +50 -0
- autotrain-advanced/src/autotrain/languages.py +19 -0
- autotrain-advanced/src/autotrain/params.py +512 -0
- autotrain-advanced/src/autotrain/preprocessor/__init__.py +0 -0
- autotrain-advanced/src/autotrain/preprocessor/dreambooth.py +62 -0
- autotrain-advanced/src/autotrain/preprocessor/tabular.py +99 -0
README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
---
|
| 3 |
+
base_model: stabilityai/stable-diffusion-xl-base-1.0
|
| 4 |
+
instance_prompt: a photo of sks dog
|
| 5 |
+
tags:
|
| 6 |
+
- text-to-image
|
| 7 |
+
- diffusers
|
| 8 |
+
- autotrain
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# DreamBooth trained by AutoTrain
|
| 13 |
+
|
| 14 |
+
Text encoder was not trained.
|
| 15 |
+
|
autotrain-advanced/.dockerignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build/
|
| 2 |
+
dist/
|
| 3 |
+
logs/
|
| 4 |
+
output/
|
| 5 |
+
output2/
|
| 6 |
+
test/
|
| 7 |
+
test.py
|
| 8 |
+
.DS_Store
|
| 9 |
+
.vscode/
|
autotrain-advanced/.github/workflows/build_documentation.yml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Build documentation
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
- doc-builder*
|
| 8 |
+
- v*-release
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
build:
|
| 12 |
+
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
| 13 |
+
with:
|
| 14 |
+
commit_sha: ${{ github.sha }}
|
| 15 |
+
package: autotrain-advanced
|
| 16 |
+
package_name: autotrain
|
| 17 |
+
secrets:
|
| 18 |
+
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
| 19 |
+
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
autotrain-advanced/.github/workflows/build_pr_documentation.yml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Build PR Documentation
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
|
| 6 |
+
concurrency:
|
| 7 |
+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
| 8 |
+
cancel-in-progress: true
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
build:
|
| 12 |
+
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
| 13 |
+
with:
|
| 14 |
+
commit_sha: ${{ github.event.pull_request.head.sha }}
|
| 15 |
+
pr_number: ${{ github.event.number }}
|
| 16 |
+
package: autotrain-advanced
|
| 17 |
+
package_name: autotrain
|
autotrain-advanced/.github/workflows/code_quality.yml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Code quality
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
pull_request:
|
| 8 |
+
branches:
|
| 9 |
+
- main
|
| 10 |
+
release:
|
| 11 |
+
types:
|
| 12 |
+
- created
|
| 13 |
+
|
| 14 |
+
jobs:
|
| 15 |
+
check_code_quality:
|
| 16 |
+
name: Check code quality
|
| 17 |
+
runs-on: ubuntu-latest
|
| 18 |
+
steps:
|
| 19 |
+
- uses: actions/checkout@v2
|
| 20 |
+
- name: Set up Python 3.9
|
| 21 |
+
uses: actions/setup-python@v2
|
| 22 |
+
with:
|
| 23 |
+
python-version: 3.9
|
| 24 |
+
- name: Install dependencies
|
| 25 |
+
run: |
|
| 26 |
+
python -m pip install --upgrade pip
|
| 27 |
+
python -m pip install flake8 black isort
|
| 28 |
+
- name: Make quality
|
| 29 |
+
run: |
|
| 30 |
+
make quality
|
autotrain-advanced/.github/workflows/delete_doc_comment.yml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Delete doc comment
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_run:
|
| 5 |
+
workflows: ["Delete doc comment trigger"]
|
| 6 |
+
types:
|
| 7 |
+
- completed
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
delete:
|
| 11 |
+
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
| 12 |
+
secrets:
|
| 13 |
+
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
autotrain-advanced/.github/workflows/delete_doc_comment_trigger.yml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Delete doc comment trigger
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
types: [ closed ]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
delete:
|
| 10 |
+
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
| 11 |
+
with:
|
| 12 |
+
pr_number: ${{ github.event.number }}
|
autotrain-advanced/.github/workflows/tests.yml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
pull_request:
|
| 8 |
+
branches:
|
| 9 |
+
- main
|
| 10 |
+
release:
|
| 11 |
+
types:
|
| 12 |
+
- created
|
| 13 |
+
|
| 14 |
+
jobs:
|
| 15 |
+
tests:
|
| 16 |
+
name: Run unit tests
|
| 17 |
+
runs-on: ubuntu-latest
|
| 18 |
+
steps:
|
| 19 |
+
- uses: actions/checkout@v2
|
| 20 |
+
- name: Set up Python 3.9
|
| 21 |
+
uses: actions/setup-python@v2
|
| 22 |
+
with:
|
| 23 |
+
python-version: 3.9
|
| 24 |
+
- name: Install dependencies
|
| 25 |
+
run: |
|
| 26 |
+
python -m pip install --upgrade pip
|
| 27 |
+
python -m pip install .[dev]
|
| 28 |
+
- name: Make test
|
| 29 |
+
run: |
|
| 30 |
+
make test
|
autotrain-advanced/.github/workflows/upload_pr_documentation.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Upload PR Documentation
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_run:
|
| 5 |
+
workflows: ["Build PR Documentation"]
|
| 6 |
+
types:
|
| 7 |
+
- completed
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
build:
|
| 11 |
+
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
| 12 |
+
with:
|
| 13 |
+
package_name: autotrain
|
| 14 |
+
secrets:
|
| 15 |
+
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
| 16 |
+
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
autotrain-advanced/.gitignore
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Local stuff
|
| 2 |
+
.DS_Store
|
| 3 |
+
.vscode/
|
| 4 |
+
test/
|
| 5 |
+
test.py
|
| 6 |
+
output/
|
| 7 |
+
output2/
|
| 8 |
+
logs/
|
| 9 |
+
|
| 10 |
+
# Byte-compiled / optimized / DLL files
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*$py.class
|
| 14 |
+
|
| 15 |
+
# C extensions
|
| 16 |
+
*.so
|
| 17 |
+
|
| 18 |
+
# Distribution / packaging
|
| 19 |
+
.Python
|
| 20 |
+
build/
|
| 21 |
+
develop-eggs/
|
| 22 |
+
dist/
|
| 23 |
+
downloads/
|
| 24 |
+
eggs/
|
| 25 |
+
.eggs/
|
| 26 |
+
lib/
|
| 27 |
+
lib64/
|
| 28 |
+
parts/
|
| 29 |
+
sdist/
|
| 30 |
+
var/
|
| 31 |
+
wheels/
|
| 32 |
+
pip-wheel-metadata/
|
| 33 |
+
share/python-wheels/
|
| 34 |
+
*.egg-info/
|
| 35 |
+
.installed.cfg
|
| 36 |
+
*.egg
|
| 37 |
+
MANIFEST
|
| 38 |
+
|
| 39 |
+
# PyInstaller
|
| 40 |
+
# Usually these files are written by a python script from a template
|
| 41 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 42 |
+
*.manifest
|
| 43 |
+
*.spec
|
| 44 |
+
|
| 45 |
+
# Installer logs
|
| 46 |
+
pip-log.txt
|
| 47 |
+
pip-delete-this-directory.txt
|
| 48 |
+
|
| 49 |
+
# Unit test / coverage reports
|
| 50 |
+
htmlcov/
|
| 51 |
+
.tox/
|
| 52 |
+
.nox/
|
| 53 |
+
.coverage
|
| 54 |
+
.coverage.*
|
| 55 |
+
.cache
|
| 56 |
+
nosetests.xml
|
| 57 |
+
coverage.xml
|
| 58 |
+
*.cover
|
| 59 |
+
*.py,cover
|
| 60 |
+
.hypothesis/
|
| 61 |
+
.pytest_cache/
|
| 62 |
+
|
| 63 |
+
# Translations
|
| 64 |
+
*.mo
|
| 65 |
+
*.pot
|
| 66 |
+
|
| 67 |
+
# Django stuff:
|
| 68 |
+
*.log
|
| 69 |
+
local_settings.py
|
| 70 |
+
db.sqlite3
|
| 71 |
+
db.sqlite3-journal
|
| 72 |
+
|
| 73 |
+
# Flask stuff:
|
| 74 |
+
instance/
|
| 75 |
+
.webassets-cache
|
| 76 |
+
|
| 77 |
+
# Scrapy stuff:
|
| 78 |
+
.scrapy
|
| 79 |
+
|
| 80 |
+
# Sphinx documentation
|
| 81 |
+
docs/_build/
|
| 82 |
+
|
| 83 |
+
# PyBuilder
|
| 84 |
+
target/
|
| 85 |
+
|
| 86 |
+
# Jupyter Notebook
|
| 87 |
+
.ipynb_checkpoints
|
| 88 |
+
|
| 89 |
+
# IPython
|
| 90 |
+
profile_default/
|
| 91 |
+
ipython_config.py
|
| 92 |
+
|
| 93 |
+
# pyenv
|
| 94 |
+
.python-version
|
| 95 |
+
|
| 96 |
+
# pipenv
|
| 97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 100 |
+
# install all needed dependencies.
|
| 101 |
+
#Pipfile.lock
|
| 102 |
+
|
| 103 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 104 |
+
__pypackages__/
|
| 105 |
+
|
| 106 |
+
# Celery stuff
|
| 107 |
+
celerybeat-schedule
|
| 108 |
+
celerybeat.pid
|
| 109 |
+
|
| 110 |
+
# SageMath parsed files
|
| 111 |
+
*.sage.py
|
| 112 |
+
|
| 113 |
+
# Environments
|
| 114 |
+
.env
|
| 115 |
+
.venv
|
| 116 |
+
env/
|
| 117 |
+
venv/
|
| 118 |
+
ENV/
|
| 119 |
+
env.bak/
|
| 120 |
+
venv.bak/
|
| 121 |
+
|
| 122 |
+
# Spyder project settings
|
| 123 |
+
.spyderproject
|
| 124 |
+
.spyproject
|
| 125 |
+
|
| 126 |
+
# Rope project settings
|
| 127 |
+
.ropeproject
|
| 128 |
+
|
| 129 |
+
# mkdocs documentation
|
| 130 |
+
/site
|
| 131 |
+
|
| 132 |
+
# mypy
|
| 133 |
+
.mypy_cache/
|
| 134 |
+
.dmypy.json
|
| 135 |
+
dmypy.json
|
| 136 |
+
|
| 137 |
+
# Pyre type checker
|
| 138 |
+
.pyre/
|
autotrain-advanced/Dockerfile
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
+
TZ=UTC
|
| 5 |
+
|
| 6 |
+
ENV PATH="${HOME}/miniconda3/bin:${PATH}"
|
| 7 |
+
ARG PATH="${HOME}/miniconda3/bin:${PATH}"
|
| 8 |
+
|
| 9 |
+
RUN mkdir -p /tmp/model
|
| 10 |
+
RUN chown -R 1000:1000 /tmp/model
|
| 11 |
+
RUN mkdir -p /tmp/data
|
| 12 |
+
RUN chown -R 1000:1000 /tmp/data
|
| 13 |
+
|
| 14 |
+
RUN apt-get update && \
|
| 15 |
+
apt-get upgrade -y && \
|
| 16 |
+
apt-get install -y \
|
| 17 |
+
build-essential \
|
| 18 |
+
cmake \
|
| 19 |
+
curl \
|
| 20 |
+
ca-certificates \
|
| 21 |
+
gcc \
|
| 22 |
+
git \
|
| 23 |
+
locales \
|
| 24 |
+
net-tools \
|
| 25 |
+
wget \
|
| 26 |
+
libpq-dev \
|
| 27 |
+
libsndfile1-dev \
|
| 28 |
+
git \
|
| 29 |
+
git-lfs \
|
| 30 |
+
libgl1 \
|
| 31 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
|
| 35 |
+
git lfs install
|
| 36 |
+
|
| 37 |
+
WORKDIR /app
|
| 38 |
+
RUN mkdir -p /app/.cache
|
| 39 |
+
ENV HF_HOME="/app/.cache"
|
| 40 |
+
RUN chown -R 1000:1000 /app
|
| 41 |
+
USER 1000
|
| 42 |
+
ENV HOME=/app
|
| 43 |
+
|
| 44 |
+
ENV PYTHONPATH=$HOME/app \
|
| 45 |
+
PYTHONUNBUFFERED=1 \
|
| 46 |
+
GRADIO_ALLOW_FLAGGING=never \
|
| 47 |
+
GRADIO_NUM_PORTS=1 \
|
| 48 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 49 |
+
SYSTEM=spaces
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
| 53 |
+
&& sh Miniconda3-latest-Linux-x86_64.sh -b -p /app/miniconda \
|
| 54 |
+
&& rm -f Miniconda3-latest-Linux-x86_64.sh
|
| 55 |
+
ENV PATH /app/miniconda/bin:$PATH
|
| 56 |
+
|
| 57 |
+
RUN conda create -p /app/env -y python=3.9
|
| 58 |
+
|
| 59 |
+
SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"]
|
| 60 |
+
|
| 61 |
+
RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 62 |
+
RUN pip install git+https://github.com/huggingface/peft.git
|
| 63 |
+
COPY --chown=1000:1000 . /app/
|
| 64 |
+
|
| 65 |
+
RUN pip install -e .
|
autotrain-advanced/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
autotrain-advanced/Makefile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: quality style test
|
| 2 |
+
|
| 3 |
+
# Check that source code meets quality standards
|
| 4 |
+
|
| 5 |
+
quality:
|
| 6 |
+
black --check --line-length 119 --target-version py38 .
|
| 7 |
+
isort --check-only .
|
| 8 |
+
flake8 --max-line-length 119
|
| 9 |
+
|
| 10 |
+
# Format source code automatically
|
| 11 |
+
|
| 12 |
+
style:
|
| 13 |
+
black --line-length 119 --target-version py38 .
|
| 14 |
+
isort .
|
| 15 |
+
|
| 16 |
+
test:
|
| 17 |
+
pytest -sv ./src/
|
| 18 |
+
|
| 19 |
+
docker:
|
| 20 |
+
docker build -t autotrain-advanced:latest .
|
| 21 |
+
docker tag autotrain-advanced:latest huggingface/autotrain-advanced:latest
|
| 22 |
+
docker push huggingface/autotrain-advanced:latest
|
| 23 |
+
|
| 24 |
+
pip:
|
| 25 |
+
rm -rf build/
|
| 26 |
+
rm -rf dist/
|
| 27 |
+
python setup.py sdist bdist_wheel
|
| 28 |
+
twine upload dist/* --verbose
|
autotrain-advanced/README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🤗 AutoTrain Advanced
|
| 2 |
+
|
| 3 |
+
AutoTrain Advanced: faster and easier training and deployments of state-of-the-art machine learning models
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
You can Install AutoTrain-Advanced python package via PIP. Please note you will need python >= 3.8 for AutoTrain Advanced to work properly.
|
| 8 |
+
|
| 9 |
+
pip install autotrain-advanced
|
| 10 |
+
|
| 11 |
+
Please make sure that you have git lfs installed. Check out the instructions here: https://github.com/git-lfs/git-lfs/wiki/Installation
|
| 12 |
+
|
| 13 |
+
## Coming Soon!
|
autotrain-advanced/docs/source/_toctree.yml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- sections:
|
| 2 |
+
- local: index
|
| 3 |
+
title: 🤗 AutoTrain
|
| 4 |
+
- local: getting_started
|
| 5 |
+
title: Installation
|
| 6 |
+
- local: cost
|
| 7 |
+
title: How much does it cost?
|
| 8 |
+
- local: support
|
| 9 |
+
title: Get help and support
|
| 10 |
+
title: Get started
|
| 11 |
+
- sections:
|
| 12 |
+
- local: model_choice
|
| 13 |
+
title: Model Selection
|
| 14 |
+
- local: param_choice
|
| 15 |
+
title: Parameter Selection
|
| 16 |
+
title: Selecting Models and Parameters
|
| 17 |
+
- sections:
|
| 18 |
+
- local: text_classification
|
| 19 |
+
title: Text Classification
|
| 20 |
+
- local: llm_finetuning
|
| 21 |
+
title: LLM Finetuning
|
| 22 |
+
title: Text Tasks
|
| 23 |
+
- sections:
|
| 24 |
+
- local: image_classification
|
| 25 |
+
title: Image Classification
|
| 26 |
+
- local: dreambooth
|
| 27 |
+
title: DreamBooth
|
| 28 |
+
title: Image Tasks
|
autotrain-advanced/docs/source/cost.mdx
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How much does it cost?
|
| 2 |
+
|
| 3 |
+
AutoTrain provides you with best models which are deployable with just a few clicks.
|
| 4 |
+
Unlike other services, we don't own your models. Once the training is done, you can download them and use them anywhere you want.
|
| 5 |
+
|
| 6 |
+
Before you start training, you can see the estimated cost of training.
|
| 7 |
+
|
| 8 |
+
Free tier is available for everyone. For a limited number of samples, you can train your models for free!
|
| 9 |
+
If your dataset is larger, you will be presented with the estimated cost of training.
|
| 10 |
+
Training will begin only after you confirm the payment.
|
| 11 |
+
|
| 12 |
+
Please note that in order to use non-free tier AutoTrain, you need to have a valid payment method on file.
|
| 13 |
+
You can add your payment method in the [billing](https://huggingface.co/settings/billing) section.
|
| 14 |
+
|
| 15 |
+
Estimated cost will be displayed in the UI as follows:
|
| 16 |
+
|
| 17 |
+

|
autotrain-advanced/docs/source/dreambooth.mdx
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DreamBooth
|
| 2 |
+
|
| 3 |
+
DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject. It allows the model to generate contextualized images of the subject in different scenes, poses, and views.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
## Data Preparation
|
| 8 |
+
|
| 9 |
+
The data format for DreamBooth training is simple. All you need is images of a concept (e.g. a person) and a concept token.
|
| 10 |
+
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
To train a dreambooth model, please select an appropriate model from the hub. You can also let AutoTrain decide the best model for you!
|
| 14 |
+
When choosing a model from the hub, please make sure you select the correct image size compatible with the model.
|
| 15 |
+
|
| 16 |
+
Same as other tasks, you also have an option to select the parameters manually or automatically using AutoTrain.
|
| 17 |
+
|
| 18 |
+
For each concept that you want to train, you must have a concept token and concept images. Concept token is nothing but a word that is not available in the dictionary.
|
autotrain-advanced/docs/source/getting_started.mdx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Installation
|
| 2 |
+
|
| 3 |
+
There is no installation required! AutoTrain Advanced runs on Hugging Face Spaces. All you need to do is create a new space with the AutoTrain Advanced template: https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced. Please make sure you keep the space private.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Once you have selected Docker > AutoTrain template. You can click on "Create Space" and you will be redirected to your new space.
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
Once the space is built, you will see this screen:
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
You can find your token at https://huggingface.co/settings/token.
|
| 16 |
+
|
| 17 |
+
Note: you have to add HF_TOKEN as an environment variable in your space settings. To do so, click on the "Settings" button in the top right corner of your space, then click on "New Secret" in the "Repository Secrets" section and add a new variable with the name HF_TOKEN and your token as the value as shown below:
|
| 18 |
+
|
| 19 |
+

|
| 20 |
+
|
| 21 |
+
# Updating AutoTrain Advanced to Latest Version
|
| 22 |
+
|
| 23 |
+
We are constantly adding new features and tasks to AutoTrain Advanced. It's always a good idea to update your space to the latest version before starting a new project. An up-to-date version of AutoTrain Advanced will have the latest tasks, features and bug fixes! Updating is as easy as clicking on the "Factory reboot" button in the settings page of your space.
|
| 24 |
+
|
| 25 |
+

|
| 26 |
+
|
| 27 |
+
Please note that "restarting" a space will not update it to the latest version. You need to "Factory reboot" the space to update it to the latest version.
|
| 28 |
+
|
| 29 |
+
And now we are all set and we can start with our first project!
|
autotrain-advanced/docs/source/image_classification.mdx
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image Classification
|
| 2 |
+
|
| 3 |
+
Image classification is a supervised learning problem: define a set of target classes (objects to identify in images), and train a model to recognize them using labeled example photos.
|
| 4 |
+
Using AutoTrain, it's super-easy to train a state-of-the-art image classification model. Just upload a set of images, and AutoTrain will automatically train a model to classify them.
|
| 5 |
+
|
| 6 |
+
## Data Preparation
|
| 7 |
+
|
| 8 |
+
The data for image classification must be in zip format, with each class in a separate subfolder. For example, if you want to classify cats and dogs, your zip file should look like this:
|
| 9 |
+
|
| 10 |
+
```
|
| 11 |
+
cats_and_dogs.zip
|
| 12 |
+
├── cats
|
| 13 |
+
│ ├── cat.1.jpg
|
| 14 |
+
│ ├── cat.2.jpg
|
| 15 |
+
│ ├── cat.3.jpg
|
| 16 |
+
│ └── ...
|
| 17 |
+
└── dogs
|
| 18 |
+
├── dog.1.jpg
|
| 19 |
+
├── dog.2.jpg
|
| 20 |
+
├── dog.3.jpg
|
| 21 |
+
└── ...
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Some points to keep in mind:
|
| 25 |
+
|
| 26 |
+
- The zip file should contain multiple folders (the classes), each folder should contain images of a single class.
|
| 27 |
+
- The name of the folder should be the name of the class.
|
| 28 |
+
- The images must be jpeg, jpg or png.
|
| 29 |
+
- There should be at least 5 images per class.
|
| 30 |
+
- There should not be any other files in the zip file.
|
| 31 |
+
- There should not be any other folders inside the zip folder.
|
| 32 |
+
|
| 33 |
+
When train.zip is decompressed, it creates two folders: cats and dogs. These are the two categories for classification. The images for both categories are in their respective folders. You can have as many categories as you want.
|
| 34 |
+
|
| 35 |
+
## Training
|
| 36 |
+
|
| 37 |
+
Once you have your data ready, you can upload it to AutoTrain and select model and parameters.
|
| 38 |
+
If the estimate looks good, click on `Create Project` button to start training.
|
| 39 |
+
|
| 40 |
+

|
autotrain-advanced/docs/source/index.mdx
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AutoTrain
|
| 2 |
+
|
| 3 |
+
🤗 AutoTrain is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use.
|
| 4 |
+
|
| 5 |
+
## Who should use AutoTrain?
|
| 6 |
+
|
| 7 |
+
AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or Tabular task, but doesn't want to spend time on the technical details of training a model. AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model. Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users.
|
| 8 |
+
|
| 9 |
+
## How to use AutoTrain?
|
| 10 |
+
|
| 11 |
+
We offer several ways to use AutoTrain:
|
| 12 |
+
|
| 13 |
+
- No code users with large number of data samples can use `AutoTrain Advanced` by creating a new space with AutoTrain Docker image: https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced. Please make sure you keep the space private.
|
| 14 |
+
|
| 15 |
+
- No code users with small number of data samples can use AutoTrain using the UI located at: https://ui.autotrain.huggingface.co/projects. Please note that this UI won't be updated with new tasks and features as frequently as AutoTrain Advanced.
|
| 16 |
+
|
| 17 |
+
- Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally. The python api is available in the `autotrain-advanced` package. You can install it using pip:
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
pip install autotrain-advanced
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
- Developers can also use the AutoTrain API directly. The API is available at: https://api.autotrain.huggingface.co/docs
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
## What is AutoTrain Advanced?
|
| 27 |
+
|
| 28 |
+
AutoTrain Advanced processes your data either in a Hugging Face Space or locally (if installed locally using pip). This saves you time since the data processing is not done by the AutoTrain backend, resulting in your job not being queued. AutoTrain Advanced also allows you to use your own hardware (better CPU and RAM) to process the data, thus, making the data processing faster.
|
| 29 |
+
|
| 30 |
+
Using AutoTrain Advanced, advanced users can also control the hyperparameters used for training per job. This allows you to train multiple models with different hyperparameters and compare the results.
|
| 31 |
+
|
| 32 |
+
Everything else is the same as AutoTrain. You can use AutoTrain Advanced to train models for NLP, CV, Speech and Tabular tasks.
|
| 33 |
+
|
| 34 |
+
We recommend using [AutoTrain Advanced](https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced) since it is faster, more flexible and will have more supported tasks and features in the future.
|
autotrain-advanced/docs/source/llm_finetuning.mdx
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM Finetuning
|
| 2 |
+
|
| 3 |
+
With AutoTrain, you can easily finetune large language models (LLMs) on your own data!
|
| 4 |
+
|
| 5 |
+
AutoTrain supports the following types of LLM finetuning:
|
| 6 |
+
|
| 7 |
+
- Causal Language Modeling (CLM)
|
| 8 |
+
- Masked Language Modeling (MLM) [Coming Soon]
|
| 9 |
+
|
| 10 |
+
For LLM finetuning, only Hugging Face Hub model choice is available.
|
| 11 |
+
Users need to select a model from the Hugging Face Hub that they want to finetune and select the parameters on their own (Manual Parameter Selection),
|
| 12 |
+
or use AutoTrain's Auto Parameter Selection to automatically select the best parameters for the task.
|
| 13 |
+
|
| 14 |
+
## Data Preparation
|
| 15 |
+
|
| 16 |
+
LLM finetuning accepts data in CSV format.
|
| 17 |
+
There are two modes for LLM finetuning: `generic` and `chat`.
|
| 18 |
+
An example dataset with both formats in the same dataset can be found here: https://huggingface.co/datasets/tatsu-lab/alpaca
|
| 19 |
+
|
| 20 |
+
### Generic
|
| 21 |
+
|
| 22 |
+
In generic mode, only one column is required: `text`.
|
| 23 |
+
The user can take care of how the data is formatted for the task.
|
| 24 |
+
A sample instance for this format is presented below:
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
| 28 |
+
|
| 29 |
+
### Instruction: Evaluate this sentence for spelling and grammar mistakes
|
| 30 |
+
|
| 31 |
+
### Input: He finnished his meal and left the resturant
|
| 32 |
+
|
| 33 |
+
### Response: He finished his meal and left the restaurant.
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
Please note that above is the format for instruction finetuning. But in the `generic` mode, you can also finetune on any other format as you want. The data can be changed according to the requirements.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## Training
|
| 42 |
+
|
| 43 |
+
Once you have your data ready and estimate verified, you can start training your model by clicking the "Create Project" button.
|
autotrain-advanced/docs/source/model_choice.mdx
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Choice
|
| 2 |
+
|
| 3 |
+
AutoTrain can automagically select the best models for your task! However, you are also
|
| 4 |
+
allowed to choose the models you want to use. You can choose the most appropriate models
|
| 5 |
+
from the Hugging Face Hub.
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
## AutoTrain Model Choice
|
| 10 |
+
|
| 11 |
+
To let AutoTrain choose the best models for your task, you can use the "AutoTrain"
|
| 12 |
+
in the "Model Choice" section. Once you choose AutoTrain mode, you no longer need to worry about model and parameter selection.
|
| 13 |
+
AutoTrain will automatically select the best models (and parameters) for your task.
|
| 14 |
+
|
| 15 |
+
## Manual Model Choice
|
| 16 |
+
|
| 17 |
+
To choose the models manually, you can use the "HuggingFace Hub" in the "Model Choice" section.
|
| 18 |
+
For example, if you are training a text classification task and want to choose Deberta V3 Base for your task
|
| 19 |
+
from https://huggingface.co/microsoft/deberta-v3-base,
|
| 20 |
+
You can choose "HuggingFace Hub" and then write the model name: `microsoft/deberta-v3-base` in the model name field.
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
|
| 24 |
+
Please note that if you are selecting a hub model, you should make sure that it is compatible with your task, otherwise the training will fail.
|
autotrain-advanced/docs/source/param_choice.mdx
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Parameter Choice
|
| 2 |
+
|
| 3 |
+
Just like model choice, you can choose the parameters for your job in two ways: AutoTrain and Manual.
|
| 4 |
+
|
| 5 |
+
## AutoTrain Mode
|
| 6 |
+
|
| 7 |
+
In the AutoTrain mode, the parameters for your task-model pair will be chosen automagically.
|
| 8 |
+
If you choose "AutoTrain" as model choice, you get the AutoTrain mode as the only option.
|
| 9 |
+
If you choose "HuggingFace Hub" as model choice, you get the option to choose between AutoTrain and Manual mode for parameter choice.
|
| 10 |
+
|
| 11 |
+
An example of AutoTrain mode for a text classification task is shown below:
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
For most of the tasks in AutoTrain parameter selection mode, you will get "Number of Models" as the only parameter to choose. Some tasks like text-classification might ask you about the language of the dataset.
|
| 16 |
+
The more the number of models, the better the final results might be but it might be more expensive too!
|
| 17 |
+
|
| 18 |
+
## Manual Mode
|
| 19 |
+
|
| 20 |
+
Manual mode can be used only when you choose "HuggingFace Hub" as model choice. In this mode, you can choose the parameters for your task-model pair manually.
|
| 21 |
+
An example of Manual mode for a text classification task is shown below:
|
| 22 |
+
|
| 23 |
+

|
| 24 |
+
|
| 25 |
+
In the manual mode, you have to add the jobs on your own. So, carefully select your parameters, click on "Add Job" and 💥.
|
autotrain-advanced/docs/source/support.mdx
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Help and Support
|
| 2 |
+
|
| 3 |
+
To get help and support for autotrain, there are 3 ways:
|
| 4 |
+
|
| 5 |
+
- [Create an issue](https://github.com/huggingface/autotrain-advanced/issues/new) in AutoTrain Advanced GitHub repository.
|
| 6 |
+
|
| 7 |
+
- [Ask in the Hugging Face Forum](https://discuss.huggingface.co/c/autotrain/16).
|
| 8 |
+
|
| 9 |
+
- [Email us](mailto:autotrain@hf.co) directly.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Please don't forget to mention your username and project name if you have a specific question about your project.
|
autotrain-advanced/docs/source/text_classification.mdx
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Text Classification
|
| 2 |
+
|
| 3 |
+
Training a text classification model with AutoTrain is super-easy! Get your data ready in
|
| 4 |
+
proper format and then with just a few clicks, your state-of-the-art model will be ready to
|
| 5 |
+
be used in production.
|
| 6 |
+
|
| 7 |
+
## Data Format
|
| 8 |
+
|
| 9 |
+
Let's train a model for classifying the sentiment of a movie review. The data should be
|
| 10 |
+
in the following CSV format:
|
| 11 |
+
|
| 12 |
+
```csv
|
| 13 |
+
review,sentiment
|
| 14 |
+
"this movie is great",positive
|
| 15 |
+
"this movie is bad",negative
|
| 16 |
+
.
|
| 17 |
+
.
|
| 18 |
+
.
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
As you can see, we have two columns in the CSV file. One column is the text and the other
|
| 22 |
+
is the label. The label can be any string. In this example, we have two labels: `positive`
|
| 23 |
+
and `negative`. You can have as many labels as you want.
|
| 24 |
+
|
| 25 |
+
If your CSV is huge, you can divide it into multiple CSV files and upload them separately.
|
| 26 |
+
Please make sure that the column names are the same in all CSV files.
|
| 27 |
+
|
| 28 |
+
One way to divide the CSV file using pandas is as follows:
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
import pandas as pd
|
| 32 |
+
|
| 33 |
+
# Set the chunk size
|
| 34 |
+
chunk_size = 1000
|
| 35 |
+
i = 1
|
| 36 |
+
|
| 37 |
+
# Open the CSV file and read it in chunks
|
| 38 |
+
for chunk in pd.read_csv('example.csv', chunksize=chunk_size):
|
| 39 |
+
# Save each chunk to a new file
|
| 40 |
+
chunk.to_csv(f'chunk_{i}.csv', index=False)
|
| 41 |
+
i += 1
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Once the data has been uploaded, you have to select the proper column mapping
|
| 45 |
+
|
| 46 |
+
## Column Mapping
|
| 47 |
+
|
| 48 |
+

|
| 49 |
+
|
| 50 |
+
In our example, the text column is called `review` and the label column is called `sentiment`.
|
| 51 |
+
Thus, we have to select `review` for the text column and `sentiment` for the label column.
|
| 52 |
+
Please note that, if column mapping is not done correctly, the training will fail.
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## Training
|
| 56 |
+
|
| 57 |
+
Once you have uploaded the data, selected the column mapping, and set the hyperparameters (AutoTrain or Manual mode), you can start the training.
|
| 58 |
+
To start the training, please confirm the estimated cost and click on the `Create Project` button.
|
| 59 |
+
|
| 60 |
+
|
autotrain-advanced/examples/text_classification_binary.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from uuid import uuid4
|
| 3 |
+
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
|
| 6 |
+
from autotrain.dataset import AutoTrainDataset
|
| 7 |
+
from autotrain.project import Project
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
RANDOM_ID = str(uuid4())
|
| 11 |
+
DATASET = "imdb"
|
| 12 |
+
PROJECT_NAME = f"imdb_{RANDOM_ID}"
|
| 13 |
+
TASK = "text_binary_classification"
|
| 14 |
+
MODEL = "bert-base-uncased"
|
| 15 |
+
|
| 16 |
+
USERNAME = os.environ["AUTOTRAIN_USERNAME"]
|
| 17 |
+
TOKEN = os.environ["HF_TOKEN"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
dataset = load_dataset(DATASET)
|
| 22 |
+
train = dataset["train"]
|
| 23 |
+
validation = dataset["test"]
|
| 24 |
+
|
| 25 |
+
# convert to pandas dataframe
|
| 26 |
+
train_df = train.to_pandas()
|
| 27 |
+
validation_df = validation.to_pandas()
|
| 28 |
+
|
| 29 |
+
# prepare dataset for AutoTrain
|
| 30 |
+
dset = AutoTrainDataset(
|
| 31 |
+
train_data=[train_df],
|
| 32 |
+
valid_data=[validation_df],
|
| 33 |
+
task=TASK,
|
| 34 |
+
token=TOKEN,
|
| 35 |
+
project_name=PROJECT_NAME,
|
| 36 |
+
username=USERNAME,
|
| 37 |
+
column_mapping={"text": "text", "label": "label"},
|
| 38 |
+
percent_valid=None,
|
| 39 |
+
)
|
| 40 |
+
dset.prepare()
|
| 41 |
+
|
| 42 |
+
#
|
| 43 |
+
# How to get params for a task:
|
| 44 |
+
#
|
| 45 |
+
# from autotrain.params import Params
|
| 46 |
+
# params = Params(task=TASK, training_type="hub_model").get()
|
| 47 |
+
# print(params) to get full list of params for the task
|
| 48 |
+
|
| 49 |
+
# define params in proper format
|
| 50 |
+
job1 = {
|
| 51 |
+
"task": TASK,
|
| 52 |
+
"learning_rate": 1e-5,
|
| 53 |
+
"optimizer": "adamw_torch",
|
| 54 |
+
"scheduler": "linear",
|
| 55 |
+
"epochs": 5,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
job2 = {
|
| 59 |
+
"task": TASK,
|
| 60 |
+
"learning_rate": 3e-5,
|
| 61 |
+
"optimizer": "adamw_torch",
|
| 62 |
+
"scheduler": "cosine",
|
| 63 |
+
"epochs": 5,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
job3 = {
|
| 67 |
+
"task": TASK,
|
| 68 |
+
"learning_rate": 5e-5,
|
| 69 |
+
"optimizer": "sgd",
|
| 70 |
+
"scheduler": "cosine",
|
| 71 |
+
"epochs": 5,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
jobs = [job1, job2, job3]
|
| 75 |
+
project = Project(dataset=dset, hub_model=MODEL, job_params=jobs)
|
| 76 |
+
project_id = project.create()
|
| 77 |
+
project.approve(project_id)
|
autotrain-advanced/examples/text_classification_multiclass.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from uuid import uuid4
|
| 3 |
+
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
|
| 6 |
+
from autotrain.dataset import AutoTrainDataset
|
| 7 |
+
from autotrain.project import Project
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
RANDOM_ID = str(uuid4())
|
| 11 |
+
DATASET = "amazon_reviews_multi"
|
| 12 |
+
PROJECT_NAME = f"amazon_reviews_multi_{RANDOM_ID}"
|
| 13 |
+
TASK = "text_multi_class_classification"
|
| 14 |
+
MODEL = "bert-base-uncased"
|
| 15 |
+
|
| 16 |
+
USERNAME = os.environ["AUTOTRAIN_USERNAME"]
|
| 17 |
+
TOKEN = os.environ["HF_TOKEN"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
dataset = load_dataset(DATASET, "en")
|
| 22 |
+
train = dataset["train"]
|
| 23 |
+
validation = dataset["test"]
|
| 24 |
+
|
| 25 |
+
# convert to pandas dataframe
|
| 26 |
+
train_df = train.to_pandas()
|
| 27 |
+
validation_df = validation.to_pandas()
|
| 28 |
+
|
| 29 |
+
# prepare dataset for AutoTrain
|
| 30 |
+
dset = AutoTrainDataset(
|
| 31 |
+
train_data=[train_df],
|
| 32 |
+
valid_data=[validation_df],
|
| 33 |
+
task=TASK,
|
| 34 |
+
token=TOKEN,
|
| 35 |
+
project_name=PROJECT_NAME,
|
| 36 |
+
username=USERNAME,
|
| 37 |
+
column_mapping={"text": "review_body", "label": "stars"},
|
| 38 |
+
percent_valid=None,
|
| 39 |
+
)
|
| 40 |
+
dset.prepare()
|
| 41 |
+
|
| 42 |
+
#
|
| 43 |
+
# How to get params for a task:
|
| 44 |
+
#
|
| 45 |
+
# from autotrain.params import Params
|
| 46 |
+
# params = Params(task=TASK, training_type="hub_model").get()
|
| 47 |
+
# print(params) to get full list of params for the task
|
| 48 |
+
|
| 49 |
+
# define params in proper format
|
| 50 |
+
job1 = {
|
| 51 |
+
"task": TASK,
|
| 52 |
+
"learning_rate": 1e-5,
|
| 53 |
+
"optimizer": "adamw_torch",
|
| 54 |
+
"scheduler": "linear",
|
| 55 |
+
"epochs": 5,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
job2 = {
|
| 59 |
+
"task": TASK,
|
| 60 |
+
"learning_rate": 3e-5,
|
| 61 |
+
"optimizer": "adamw_torch",
|
| 62 |
+
"scheduler": "cosine",
|
| 63 |
+
"epochs": 5,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
job3 = {
|
| 67 |
+
"task": TASK,
|
| 68 |
+
"learning_rate": 5e-5,
|
| 69 |
+
"optimizer": "sgd",
|
| 70 |
+
"scheduler": "cosine",
|
| 71 |
+
"epochs": 5,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
jobs = [job1, job2, job3]
|
| 75 |
+
project = Project(dataset=dset, hub_model=MODEL, job_params=jobs)
|
| 76 |
+
project_id = project.create()
|
| 77 |
+
project.approve(project_id)
|
autotrain-advanced/requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
albumentations==1.3.1
|
| 2 |
+
codecarbon==2.2.3
|
| 3 |
+
datasets[vision]~=2.14.0
|
| 4 |
+
evaluate==0.3.0
|
| 5 |
+
ipadic==1.0.0
|
| 6 |
+
jiwer==3.0.2
|
| 7 |
+
joblib==1.3.1
|
| 8 |
+
loguru==0.7.0
|
| 9 |
+
pandas==2.0.3
|
| 10 |
+
Pillow==10.0.0
|
| 11 |
+
protobuf==4.23.4
|
| 12 |
+
pydantic==1.10.11
|
| 13 |
+
sacremoses==0.0.53
|
| 14 |
+
scikit-learn==1.3.0
|
| 15 |
+
sentencepiece==0.1.99
|
| 16 |
+
tqdm==4.65.0
|
| 17 |
+
werkzeug==2.3.6
|
| 18 |
+
huggingface_hub>=0.16.4
|
| 19 |
+
requests==2.31.0
|
| 20 |
+
gradio==3.39.0
|
| 21 |
+
einops==0.6.1
|
| 22 |
+
invisible-watermark==0.2.0
|
| 23 |
+
# latest versions
|
| 24 |
+
tensorboard
|
| 25 |
+
peft
|
| 26 |
+
trl
|
| 27 |
+
tiktoken
|
| 28 |
+
transformers
|
| 29 |
+
accelerate
|
| 30 |
+
diffusers
|
| 31 |
+
bitsandbytes
|
autotrain-advanced/setup.cfg
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[metadata]
|
| 2 |
+
license_files = LICENSE
|
| 3 |
+
version = attr: autotrain.__version__
|
| 4 |
+
|
| 5 |
+
[isort]
|
| 6 |
+
ensure_newline_before_comments = True
|
| 7 |
+
force_grid_wrap = 0
|
| 8 |
+
include_trailing_comma = True
|
| 9 |
+
line_length = 119
|
| 10 |
+
lines_after_imports = 2
|
| 11 |
+
multi_line_output = 3
|
| 12 |
+
use_parentheses = True
|
| 13 |
+
|
| 14 |
+
[flake8]
|
| 15 |
+
ignore = E203, E501, W503
|
| 16 |
+
max-line-length = 119
|
| 17 |
+
per-file-ignores =
|
| 18 |
+
# imported but unused
|
| 19 |
+
__init__.py: F401
|
| 20 |
+
exclude =
|
| 21 |
+
.git,
|
| 22 |
+
.venv,
|
| 23 |
+
__pycache__,
|
| 24 |
+
dist
|
autotrain-advanced/setup.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Lint as: python3
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace / AutoTrain Advanced
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from setuptools import find_packages, setup
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DOCLINES = __doc__.split("\n")
|
| 11 |
+
|
| 12 |
+
this_directory = os.path.abspath(os.path.dirname(__file__))
|
| 13 |
+
with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
|
| 14 |
+
LONG_DESCRIPTION = f.read()
|
| 15 |
+
|
| 16 |
+
# get INSTALL_REQUIRES from requirements.txt
|
| 17 |
+
with open(os.path.join(this_directory, "requirements.txt"), encoding="utf-8") as f:
|
| 18 |
+
INSTALL_REQUIRES = f.read().splitlines()
|
| 19 |
+
|
| 20 |
+
QUALITY_REQUIRE = [
|
| 21 |
+
"black",
|
| 22 |
+
"isort",
|
| 23 |
+
"flake8==3.7.9",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
TESTS_REQUIRE = ["pytest"]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
EXTRAS_REQUIRE = {
|
| 30 |
+
"dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
|
| 31 |
+
"quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
|
| 32 |
+
"docs": INSTALL_REQUIRES
|
| 33 |
+
+ [
|
| 34 |
+
"recommonmark",
|
| 35 |
+
"sphinx==3.1.2",
|
| 36 |
+
"sphinx-markdown-tables",
|
| 37 |
+
"sphinx-rtd-theme==0.4.3",
|
| 38 |
+
"sphinx-copybutton",
|
| 39 |
+
],
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
setup(
|
| 43 |
+
name="autotrain-advanced",
|
| 44 |
+
description=DOCLINES[0],
|
| 45 |
+
long_description=LONG_DESCRIPTION,
|
| 46 |
+
long_description_content_type="text/markdown",
|
| 47 |
+
author="HuggingFace Inc.",
|
| 48 |
+
author_email="autotrain@huggingface.co",
|
| 49 |
+
url="https://github.com/huggingface/autotrain-advanced",
|
| 50 |
+
download_url="https://github.com/huggingface/autotrain-advanced/tags",
|
| 51 |
+
license="Apache 2.0",
|
| 52 |
+
package_dir={"": "src"},
|
| 53 |
+
packages=find_packages("src"),
|
| 54 |
+
extras_require=EXTRAS_REQUIRE,
|
| 55 |
+
install_requires=INSTALL_REQUIRES,
|
| 56 |
+
entry_points={"console_scripts": ["autotrain=autotrain.cli.autotrain:main"]},
|
| 57 |
+
classifiers=[
|
| 58 |
+
"Development Status :: 5 - Production/Stable",
|
| 59 |
+
"Intended Audience :: Developers",
|
| 60 |
+
"Intended Audience :: Education",
|
| 61 |
+
"Intended Audience :: Science/Research",
|
| 62 |
+
"License :: OSI Approved :: Apache Software License",
|
| 63 |
+
"Operating System :: OS Independent",
|
| 64 |
+
"Programming Language :: Python :: 3.8",
|
| 65 |
+
"Programming Language :: Python :: 3.9",
|
| 66 |
+
"Programming Language :: Python :: 3.10",
|
| 67 |
+
"Programming Language :: Python :: 3.11",
|
| 68 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 69 |
+
],
|
| 70 |
+
keywords="automl autonlp autotrain huggingface",
|
| 71 |
+
)
|
autotrain-advanced/src/autotrain/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020-2021 The HuggingFace AutoTrain Authors
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Lint as: python3
|
| 17 |
+
# pylint: enable=line-too-long
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ignore bnb warnings
|
| 22 |
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
| 23 |
+
# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 24 |
+
__version__ = "0.6.16.dev0"
|
autotrain-advanced/src/autotrain/app.py
ADDED
|
@@ -0,0 +1,965 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import string
|
| 5 |
+
import zipfile
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from huggingface_hub import list_models
|
| 10 |
+
from loguru import logger
|
| 11 |
+
|
| 12 |
+
from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset
|
| 13 |
+
from autotrain.languages import SUPPORTED_LANGUAGES
|
| 14 |
+
from autotrain.params import Params
|
| 15 |
+
from autotrain.project import Project
|
| 16 |
+
from autotrain.utils import get_project_cost, get_user_token, user_authentication
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Project types shown in the UI, each mapped to the task labels it offers.
APP_TASKS = {
    "Natural Language Processing": ["Text Classification", "LLM Finetuning"],
    # "Tabular": TABULAR_TASKS,
    "Computer Vision": ["Image Classification", "Dreambooth"],
}

# UI task label -> internal AutoTrain task identifier.
APP_TASKS_MAPPING = {
    "Text Classification": "text_multi_class_classification",
    "LLM Finetuning": "lm_training",
    "Image Classification": "image_multi_class_classification",
    "Dreambooth": "dreambooth",
}

# Internal task family -> project type (reverse lookup for the UI).
APP_TASK_TYPE_MAPPING = {
    "text_classification": "Natural Language Processing",
    "lm_training": "Natural Language Processing",
    "image_classification": "Computer Vision",
    "dreambooth": "Computer Vision",
}

# Extensions accepted by the upload widgets; both cases are listed because
# the file picker matches extensions case-sensitively.
ALLOWED_FILE_TYPES = [
    ".csv",
    ".CSV",
    ".jsonl",
    ".JSONL",
    ".zip",
    ".ZIP",
    ".png",
    ".PNG",
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _login_user(user_token):
    """Resolve a user token into billing and training namespaces.

    Returns a tuple of (user_token, valid_can_pay, who_is_training) where
    valid_can_pay lists namespaces that can be billed and who_is_training
    lists every namespace the user may train under.
    """
    user_info = user_authentication(token=user_token)
    username = user_info["name"]
    orgs = user_info["orgs"]

    # Orgs count as billable only when they can pay AND the user can write to them.
    billable_orgs = [
        org["name"]
        for org in orgs
        if org["canPay"] is True and org["roleInOrg"] in ("admin", "write")
    ]

    valid_can_pay = ([username] if user_info["canPay"] else []) + billable_orgs
    who_is_training = [username] + [org["name"] for org in orgs]
    return user_token, valid_can_pay, who_is_training
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _update_task_type(project_type):
    """Refresh the Task dropdown with the tasks available for *project_type*."""
    tasks = APP_TASKS[project_type]
    return gr.Dropdown.update(value=tasks[0], choices=tasks, visible=True)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _update_model_choice(task, autotrain_backend):
    """Refresh the Model Choice dropdown based on the task and backend."""
    # TODO: add tabular and remember, for tabular, we only support AutoTrain
    if autotrain_backend.lower() != "huggingface internal":
        # Non-internal backends can only train Hub models.
        model_choice = ["HuggingFace Hub"]
    elif task == "LLM Finetuning":
        model_choice = ["HuggingFace Hub"]
    else:
        model_choice = ["AutoTrain", "HuggingFace Hub"]
    return gr.Dropdown.update(
        value=model_choice[0],
        choices=model_choice,
        visible=True,
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _update_file_type(task):
    """Refresh the File Type radio options for the selected task."""
    task = APP_TASKS_MAPPING[task]
    # task id -> (default value, available choices)
    options = {
        "text_multi_class_classification": ("CSV", ["CSV", "JSONL"]),
        "lm_training": ("CSV", ["CSV", "JSONL"]),
        "image_multi_class_classification": ("ZIP", ["Image Subfolders", "ZIP"]),
        "dreambooth": ("ZIP", ["Image Folder", "ZIP"]),
    }
    if task not in options:
        raise NotImplementedError
    default, choices = options[task]
    return gr.Radio.update(value=default, choices=choices, visible=True)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _update_param_choice(model_choice, autotrain_backend):
    """Refresh the Param Choice dropdown based on model choice and backend."""
    logger.info(f"model_choice: {model_choice}")
    if autotrain_backend != "HuggingFace Internal":
        # Only manual parameters are supported outside the internal backend.
        choices = ["Manual"]
    elif model_choice == "HuggingFace Hub":
        choices = ["AutoTrain", "Manual"]
    else:
        choices = ["AutoTrain"]
    return gr.Dropdown.update(value=choices[0], choices=choices, visible=True)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _project_type_update(project_type, task_type, autotrain_backend):
    """Cascade dropdown refreshes after the Project Type changes."""
    logger.info(f"project_type: {project_type}, task_type: {task_type}")
    task_update = _update_task_type(project_type)
    model_update = _update_model_choice(task_update["value"], autotrain_backend)
    param_update = _update_param_choice(model_update["value"], autotrain_backend)
    hub_update = _update_hub_model_choices(task_update["value"], model_update["value"])
    return [task_update, model_update, param_update, hub_update]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _task_type_update(task_type, autotrain_backend):
    """Cascade dropdown refreshes after the Task changes."""
    logger.info(f"task_type: {task_type}")
    model_update = _update_model_choice(task_type, autotrain_backend)
    param_update = _update_param_choice(model_update["value"], autotrain_backend)
    hub_update = _update_hub_model_choices(task_type, model_update["value"])
    return [model_update, param_update, hub_update]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _update_col_map(training_data, task):
    """Show or hide the column-mapping widgets for the selected task.

    Returns updates for [text column dropdown, target column dropdown,
    concept-token textbox], in that order.
    """
    task = APP_TASKS_MAPPING[task]
    if task in ("text_multi_class_classification", "lm_training"):
        # Peek at the first two rows just to read the header.
        columns = pd.read_csv(training_data[0].name, nrows=2).columns.tolist()
        text_update = gr.Dropdown.update(
            visible=True, choices=columns, label="Map `text` column", value=columns[0]
        )
        if task == "text_multi_class_classification":
            target_update = gr.Dropdown.update(
                visible=True, choices=columns, label="Map `target` column", value=columns[1]
            )
        else:
            # LM training only needs a text column.
            target_update = gr.Dropdown.update(visible=False)
        return [text_update, target_update, gr.Text.update(visible=False)]
    if task == "dreambooth":
        # Dreambooth takes a concept token instead of column mappings.
        return [
            gr.Dropdown.update(visible=False),
            gr.Dropdown.update(visible=False),
            gr.Text.update(visible=True, label="Concept Token", interactive=True),
        ]
    return [
        gr.Dropdown.update(visible=False),
        gr.Dropdown.update(visible=False),
        gr.Text.update(visible=False),
    ]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _estimate_costs(
    training_data, validation_data, task, user_token, autotrain_username, training_params_txt, autotrain_backend
):
    """Estimate the training cost for the current UI selection.

    Returns a two-element list of updates: a Markdown widget carrying the
    human-readable estimate (or an error message) and a hidden Number widget.
    Estimation is only supported on the HuggingFace Internal backend.
    """

    def _message(msg):
        # Helper: render *msg* in the cost Markdown and hide the number box.
        return [
            gr.Markdown.update(
                value=msg,
                visible=True,
            ),
            gr.Number.update(visible=False),
        ]

    if autotrain_backend.lower() != "huggingface internal":
        return _message("Cost estimation is not available for this backend")
    try:
        logger.info("Estimating costs....")
        if training_data is None:
            return _message("Could not estimate cost. Please add training data")
        if validation_data is None:
            validation_data = []

        training_params = json.loads(training_params_txt)
        if len(training_params) == 0:
            return _message("Could not estimate cost. Please add atleast one job")
        elif len(training_params) == 1:
            # A single AutoTrain job may itself train several candidate models.
            if "num_models" in training_params[0]:
                num_models = training_params[0]["num_models"]
            else:
                num_models = 1
        else:
            num_models = len(training_params)
        task = APP_TASKS_MAPPING[task]
        num_samples = 0
        logger.info("Estimating number of samples")
        if task in ("text_multi_class_classification", "lm_training"):
            for _f in training_data:
                num_samples += pd.read_csv(_f.name).shape[0]
            for _f in validation_data:
                num_samples += pd.read_csv(_f.name).shape[0]
        elif task == "image_multi_class_classification":
            logger.info(f"training_data: {training_data}")
            if len(training_data) > 1:
                return _message("Only one training file is supported for image classification")
            if len(validation_data) > 1:
                return _message("Only one validation file is supported for image classification")
            # Use a context manager so archive handles are closed promptly
            # (the original left them open until garbage collection).
            for _f in training_data:
                with zipfile.ZipFile(_f.name, "r") as zip_ref:
                    num_samples += len(zip_ref.namelist())
            for _f in validation_data:
                with zipfile.ZipFile(_f.name, "r") as zip_ref:
                    num_samples += len(zip_ref.namelist())
        elif task == "dreambooth":
            num_samples = len(training_data)
        else:
            raise NotImplementedError

        logger.info(f"Estimating costs for: num_models: {num_models}, task: {task}, num_samples: {num_samples}")
        estimated_cost = get_project_cost(
            username=autotrain_username,
            token=user_token,
            task=task,
            num_samples=num_samples,
            num_models=num_models,
        )
        logger.info(f"Estimated_cost: {estimated_cost}")
        return _message(
            f"Estimated cost: ${estimated_cost:.2f}. Note: clicking on 'Create Project' will start training and incur charges!"
        )
    except Exception as e:
        logger.error(e)
        logger.error("Could not estimate cost, check inputs")
        return _message("Could not estimate cost, check inputs")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def get_job_params(param_choice, training_params, task):
    """Attach the task id to each job's parameter dict.

    :param param_choice: "autotrain" or "manual" (case-insensitive; the
        original compared "autotrain" case-sensitively but "manual" with
        ``.lower()`` — both are now normalized consistently).
    :param training_params: list of job parameter dicts (mutated in place).
    :param task: internal task identifier to stamp on every job.
    :return: the (mutated) list of job parameter dicts.
    :raises ValueError: if more than one job is supplied for AutoTrain mode.
    """
    choice = param_choice.lower()
    if choice == "autotrain":
        if len(training_params) > 1:
            raise ValueError("❌ Only one job parameter is allowed for AutoTrain.")
        training_params[0].update({"task": task})
    elif choice == "manual":
        for params in training_params:
            params.update({"task": task})
            # hub_model is a UI-level selection, not a trainer hyperparameter.
            params.pop("hub_model", None)
    return training_params
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _update_project_name():
    """Generate a random xxxx-xxxx-xxxx project name.

    Also disables the Create Project button while a local training run is in
    progress (signalled by a marker file at /tmp/training).
    """
    alphabet = string.ascii_lowercase + string.digits
    random_project_name = "-".join(
        "".join(random.choices(alphabet, k=4)) for _ in range(3)
    )
    name_update = gr.Text.update(value=random_project_name, visible=True, interactive=True)
    training_in_progress = os.path.exists(os.path.join("/tmp", "training"))
    return [name_update, gr.Button.update(interactive=not training_in_progress)]
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _update_hub_model_choices(task, model_choice):
    """Populate the Hub Model dropdown with popular public models for *task*.

    Hidden entirely when the AutoTrain model choice is selected. Queries the
    Hugging Face Hub for the 100 most-downloaded models per relevant pipeline
    tag and sorts the merged list by downloads.
    """
    task = APP_TASKS_MAPPING[task]
    logger.info(f"Updating hub model choices for task: {task}, model_choice: {model_choice}")
    if model_choice.lower() == "autotrain":
        # AutoTrain picks its own models; no hub model selection needed.
        return gr.Dropdown.update(
            visible=False,
            interactive=False,
        )
    if task == "text_multi_class_classification":
        # Both MLM backbones and existing classifiers are usable starting points.
        hub_models1 = list_models(filter="fill-mask", sort="downloads", direction=-1, limit=100)
        hub_models2 = list_models(filter="text-classification", sort="downloads", direction=-1, limit=100)
        hub_models = list(hub_models1) + list(hub_models2)
    elif task == "lm_training":
        hub_models = list(list_models(filter="text-generation", sort="downloads", direction=-1, limit=100))
    elif task == "image_multi_class_classification":
        hub_models = list(list_models(filter="image-classification", sort="downloads", direction=-1, limit=100))
    elif task == "dreambooth":
        hub_models = list(list_models(filter="text-to-image", sort="downloads", direction=-1, limit=100))
    else:
        raise NotImplementedError
    # sort by number of downloads in descending order
    hub_models = [{"id": m.modelId, "downloads": m.downloads} for m in hub_models if m.private is False]
    hub_models = sorted(hub_models, key=lambda x: x["downloads"], reverse=True)

    if task == "dreambooth":
        # Pin SDXL base as the default/first choice for dreambooth.
        choices = ["stabilityai/stable-diffusion-xl-base-1.0"] + [m["id"] for m in hub_models]
        value = choices[0]
        return gr.Dropdown.update(
            choices=choices,
            value=value,
            visible=True,
            interactive=True,
        )

    return gr.Dropdown.update(
        choices=[m["id"] for m in hub_models],
        value=hub_models[0]["id"],
        visible=True,
        interactive=True,
    )
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def _update_backend(backend):
    """Refresh the model-choice and param-choice dropdowns when the backend changes.

    Bug fix: the original compared against "Hugging Face Internal" (with a
    space), but the backend dropdown's value is "HuggingFace Internal"
    (no space, see the gr.Dropdown choices in main), so the internal branch
    was unreachable. The comparison now matches the actual dropdown value.
    """
    if backend != "HuggingFace Internal":
        # External backends only support Hub models with manual parameters.
        return [
            gr.Dropdown.update(
                visible=True,
                interactive=True,
                choices=["HuggingFace Hub"],
                value="HuggingFace Hub",
            ),
            gr.Dropdown.update(
                visible=True,
                interactive=True,
                choices=["Manual"],
                value="Manual",
            ),
        ]
    # Internal backend: re-enable the dropdowns without forcing values.
    return [
        gr.Dropdown.update(
            visible=True,
            interactive=True,
        ),
        gr.Dropdown.update(
            visible=True,
            interactive=True,
        ),
    ]
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _create_project(
    autotrain_username,
    valid_can_pay,
    project_name,
    user_token,
    task,
    training_data,
    validation_data,
    col_map_text,
    col_map_label,
    concept_token,
    training_params_txt,
    hub_model,
    estimated_cost,
    autotrain_backend,
):
    """Create and launch an AutoTrain project from the current UI state.

    Builds the task-appropriate dataset object, wraps it in a Project, and
    either submits it to the internal backend (returning a dashboard link)
    or starts a local/spaces run.

    NOTE(review): on the non-internal branch this function returns None, so
    the output Markdown widget is not updated — confirm this is intended.
    """
    task = APP_TASKS_MAPPING[task]
    # valid_can_pay arrives as a comma-joined string from a hidden Textbox.
    valid_can_pay = valid_can_pay.split(",")
    can_pay = autotrain_username in valid_can_pay
    logger.info(f"🚨🚨🚨Creating project: {project_name}")
    logger.info(f"🚨Task: {task}")
    logger.info(f"🚨Training data: {training_data}")
    logger.info(f"🚨Validation data: {validation_data}")
    logger.info(f"🚨Training params: {training_params_txt}")
    logger.info(f"🚨Hub model: {hub_model}")
    logger.info(f"🚨Estimated cost: {estimated_cost}")
    logger.info(f"🚨:Can pay: {can_pay}")

    if can_pay is False and estimated_cost > 0:
        raise gr.Error("❌ You do not have enough credits to create this project. Please add a valid payment method.")

    training_params = json.loads(training_params_txt)
    # Infer the param choice from the jobs: a single job carrying "num_models"
    # means AutoTrain mode; anything else is manual.
    if len(training_params) == 0:
        raise gr.Error("Please add atleast one job")
    elif len(training_params) == 1:
        if "num_models" in training_params[0]:
            param_choice = "autotrain"
        else:
            param_choice = "manual"
    else:
        param_choice = "manual"

    if task == "image_multi_class_classification":
        # Image classification takes exactly one (zip) file per split.
        training_data = training_data[0].name
        if validation_data is not None:
            validation_data = validation_data[0].name
        dset = AutoTrainImageClassificationDataset(
            train_data=training_data,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "text_multi_class_classification":
        training_data = [f.name for f in training_data]
        if validation_data is None:
            validation_data = []
        else:
            validation_data = [f.name for f in validation_data]
        dset = AutoTrainDataset(
            train_data=training_data,
            task=task,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            column_mapping={"text": col_map_text, "label": col_map_label},
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "lm_training":
        training_data = [f.name for f in training_data]
        if validation_data is None:
            validation_data = []
        else:
            validation_data = [f.name for f in validation_data]
        dset = AutoTrainDataset(
            train_data=training_data,
            task=task,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
            column_mapping={"text": col_map_text},
            valid_data=validation_data,
            percent_valid=None,  # TODO: add to UI
        )
    elif task == "dreambooth":
        dset = AutoTrainDreamboothDataset(
            concept_images=training_data,
            concept_name=concept_token,
            token=user_token,
            project_name=project_name,
            username=autotrain_username,
        )
    else:
        raise NotImplementedError

    dset.prepare()
    project = Project(
        dataset=dset,
        param_choice=param_choice,
        hub_model=hub_model,
        job_params=get_job_params(param_choice, training_params, task),
    )
    if autotrain_backend.lower() == "huggingface internal":
        project_id = project.create()
        project.approve(project_id)
        return gr.Markdown.update(
            value=f"Project created successfully. Monitor progess on the [dashboard](https://ui.autotrain.huggingface.co/{project_id}/trainings).",
            visible=True,
        )
    else:
        # Non-internal backends run locally / on Spaces; no dashboard link.
        project.create(local=True)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def get_variable_name(var, namespace):
    """Return the first name in *namespace* bound to the identical object *var*.

    Comparison is by identity (``is``), not equality. Returns None when no
    name refers to *var*.
    """
    matches = (name for name, value in namespace.items() if value is var)
    return next(matches, None)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def disable_create_project_button():
    """Return a Gradio update that greys out the 'Create Project' button."""
    return gr.Button.update(interactive=False)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def main():
|
| 524 |
+
with gr.Blocks(theme="freddyaboulton/dracula_revamped") as demo:
|
| 525 |
+
gr.Markdown("## 🤗 AutoTrain Advanced")
|
| 526 |
+
user_token = os.environ.get("HF_TOKEN", "")
|
| 527 |
+
|
| 528 |
+
if len(user_token) == 0:
|
| 529 |
+
user_token = get_user_token()
|
| 530 |
+
|
| 531 |
+
if user_token is None:
|
| 532 |
+
gr.Markdown(
|
| 533 |
+
"""Please login with a write [token](https://huggingface.co/settings/tokens).
|
| 534 |
+
Pass your HF token in an environment variable called `HF_TOKEN` and then restart this app.
|
| 535 |
+
"""
|
| 536 |
+
)
|
| 537 |
+
return demo
|
| 538 |
+
|
| 539 |
+
user_token, valid_can_pay, who_is_training = _login_user(user_token)
|
| 540 |
+
|
| 541 |
+
if user_token is None or len(user_token) == 0:
|
| 542 |
+
gr.Error("Please login with a write token.")
|
| 543 |
+
|
| 544 |
+
user_token = gr.Textbox(
|
| 545 |
+
value=user_token, type="password", lines=1, max_lines=1, visible=False, interactive=False
|
| 546 |
+
)
|
| 547 |
+
valid_can_pay = gr.Textbox(value=",".join(valid_can_pay), visible=False, interactive=False)
|
| 548 |
+
with gr.Row():
|
| 549 |
+
with gr.Column():
|
| 550 |
+
with gr.Row():
|
| 551 |
+
autotrain_username = gr.Dropdown(
|
| 552 |
+
label="AutoTrain Username",
|
| 553 |
+
choices=who_is_training,
|
| 554 |
+
value=who_is_training[0] if who_is_training else "",
|
| 555 |
+
)
|
| 556 |
+
autotrain_backend = gr.Dropdown(
|
| 557 |
+
label="AutoTrain Backend",
|
| 558 |
+
choices=["HuggingFace Internal", "HuggingFace Spaces"],
|
| 559 |
+
value="HuggingFace Internal",
|
| 560 |
+
interactive=True,
|
| 561 |
+
)
|
| 562 |
+
with gr.Row():
|
| 563 |
+
project_name = gr.Textbox(label="Project name", value="", lines=1, max_lines=1, interactive=True)
|
| 564 |
+
project_type = gr.Dropdown(
|
| 565 |
+
label="Project Type", choices=list(APP_TASKS.keys()), value=list(APP_TASKS.keys())[0]
|
| 566 |
+
)
|
| 567 |
+
task_type = gr.Dropdown(
|
| 568 |
+
label="Task",
|
| 569 |
+
choices=APP_TASKS[list(APP_TASKS.keys())[0]],
|
| 570 |
+
value=APP_TASKS[list(APP_TASKS.keys())[0]][0],
|
| 571 |
+
interactive=True,
|
| 572 |
+
)
|
| 573 |
+
model_choice = gr.Dropdown(
|
| 574 |
+
label="Model Choice",
|
| 575 |
+
choices=["AutoTrain", "HuggingFace Hub"],
|
| 576 |
+
value="AutoTrain",
|
| 577 |
+
visible=True,
|
| 578 |
+
interactive=True,
|
| 579 |
+
)
|
| 580 |
+
hub_model = gr.Dropdown(
|
| 581 |
+
label="Hub Model",
|
| 582 |
+
value="",
|
| 583 |
+
visible=False,
|
| 584 |
+
interactive=True,
|
| 585 |
+
elem_id="hub_model",
|
| 586 |
+
)
|
| 587 |
+
gr.Markdown("<hr>")
|
| 588 |
+
with gr.Row():
|
| 589 |
+
with gr.Column():
|
| 590 |
+
with gr.Tabs(elem_id="tabs"):
|
| 591 |
+
with gr.TabItem("Data"):
|
| 592 |
+
with gr.Column():
|
| 593 |
+
# file_type_training = gr.Radio(
|
| 594 |
+
# label="File Type",
|
| 595 |
+
# choices=["CSV", "JSONL"],
|
| 596 |
+
# value="CSV",
|
| 597 |
+
# visible=True,
|
| 598 |
+
# interactive=True,
|
| 599 |
+
# )
|
| 600 |
+
training_data = gr.File(
|
| 601 |
+
label="Training Data",
|
| 602 |
+
file_types=ALLOWED_FILE_TYPES,
|
| 603 |
+
file_count="multiple",
|
| 604 |
+
visible=True,
|
| 605 |
+
interactive=True,
|
| 606 |
+
elem_id="training_data_box",
|
| 607 |
+
)
|
| 608 |
+
with gr.Accordion("Validation Data (Optional)", open=False):
|
| 609 |
+
validation_data = gr.File(
|
| 610 |
+
label="Validation Data (Optional)",
|
| 611 |
+
file_types=ALLOWED_FILE_TYPES,
|
| 612 |
+
file_count="multiple",
|
| 613 |
+
visible=True,
|
| 614 |
+
interactive=True,
|
| 615 |
+
elem_id="validation_data_box",
|
| 616 |
+
)
|
| 617 |
+
with gr.Row():
|
| 618 |
+
col_map_text = gr.Dropdown(
|
| 619 |
+
label="Text Column", choices=[], visible=False, interactive=True
|
| 620 |
+
)
|
| 621 |
+
col_map_target = gr.Dropdown(
|
| 622 |
+
label="Target Column", choices=[], visible=False, interactive=True
|
| 623 |
+
)
|
| 624 |
+
concept_token = gr.Text(
|
| 625 |
+
value="", visible=False, interactive=True, lines=1, max_lines=1
|
| 626 |
+
)
|
| 627 |
+
with gr.TabItem("Params"):
|
| 628 |
+
with gr.Row():
|
| 629 |
+
source_language = gr.Dropdown(
|
| 630 |
+
label="Source Language",
|
| 631 |
+
choices=SUPPORTED_LANGUAGES[:-1],
|
| 632 |
+
value="en",
|
| 633 |
+
visible=True,
|
| 634 |
+
interactive=True,
|
| 635 |
+
elem_id="source_language",
|
| 636 |
+
)
|
| 637 |
+
num_models = gr.Slider(
|
| 638 |
+
label="Number of Models",
|
| 639 |
+
minimum=1,
|
| 640 |
+
maximum=25,
|
| 641 |
+
value=5,
|
| 642 |
+
step=1,
|
| 643 |
+
visible=True,
|
| 644 |
+
interactive=True,
|
| 645 |
+
elem_id="num_models",
|
| 646 |
+
)
|
| 647 |
+
target_language = gr.Dropdown(
|
| 648 |
+
label="Target Language",
|
| 649 |
+
choices=["fr"],
|
| 650 |
+
value="fr",
|
| 651 |
+
visible=False,
|
| 652 |
+
interactive=True,
|
| 653 |
+
elem_id="target_language",
|
| 654 |
+
)
|
| 655 |
+
image_size = gr.Number(
|
| 656 |
+
label="Image Size",
|
| 657 |
+
value=512,
|
| 658 |
+
visible=False,
|
| 659 |
+
interactive=True,
|
| 660 |
+
elem_id="image_size",
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
with gr.Row():
|
| 664 |
+
learning_rate = gr.Number(
|
| 665 |
+
label="Learning Rate",
|
| 666 |
+
value=5e-5,
|
| 667 |
+
visible=False,
|
| 668 |
+
interactive=True,
|
| 669 |
+
elem_id="learning_rate",
|
| 670 |
+
)
|
| 671 |
+
batch_size = gr.Number(
|
| 672 |
+
label="Train Batch Size",
|
| 673 |
+
value=32,
|
| 674 |
+
visible=False,
|
| 675 |
+
interactive=True,
|
| 676 |
+
elem_id="train_batch_size",
|
| 677 |
+
)
|
| 678 |
+
num_epochs = gr.Number(
|
| 679 |
+
label="Number of Epochs",
|
| 680 |
+
value=3,
|
| 681 |
+
visible=False,
|
| 682 |
+
interactive=True,
|
| 683 |
+
elem_id="num_train_epochs",
|
| 684 |
+
)
|
| 685 |
+
with gr.Row():
|
| 686 |
+
gradient_accumulation_steps = gr.Number(
|
| 687 |
+
label="Gradient Accumulation Steps",
|
| 688 |
+
value=1,
|
| 689 |
+
visible=False,
|
| 690 |
+
interactive=True,
|
| 691 |
+
elem_id="gradient_accumulation_steps",
|
| 692 |
+
)
|
| 693 |
+
percentage_warmup_steps = gr.Number(
|
| 694 |
+
label="Percentage of Warmup Steps",
|
| 695 |
+
value=0.1,
|
| 696 |
+
visible=False,
|
| 697 |
+
interactive=True,
|
| 698 |
+
elem_id="percentage_warmup",
|
| 699 |
+
)
|
| 700 |
+
weight_decay = gr.Number(
|
| 701 |
+
label="Weight Decay",
|
| 702 |
+
value=0.01,
|
| 703 |
+
visible=False,
|
| 704 |
+
interactive=True,
|
| 705 |
+
elem_id="weight_decay",
|
| 706 |
+
)
|
| 707 |
+
with gr.Row():
|
| 708 |
+
lora_r = gr.Number(
|
| 709 |
+
label="LoraR",
|
| 710 |
+
value=16,
|
| 711 |
+
visible=False,
|
| 712 |
+
interactive=True,
|
| 713 |
+
elem_id="lora_r",
|
| 714 |
+
)
|
| 715 |
+
lora_alpha = gr.Number(
|
| 716 |
+
label="LoraAlpha",
|
| 717 |
+
value=32,
|
| 718 |
+
visible=False,
|
| 719 |
+
interactive=True,
|
| 720 |
+
elem_id="lora_alpha",
|
| 721 |
+
)
|
| 722 |
+
lora_dropout = gr.Number(
|
| 723 |
+
label="Lora Dropout",
|
| 724 |
+
value=0.1,
|
| 725 |
+
visible=False,
|
| 726 |
+
interactive=True,
|
| 727 |
+
elem_id="lora_dropout",
|
| 728 |
+
)
|
| 729 |
+
with gr.Row():
|
| 730 |
+
db_num_steps = gr.Number(
|
| 731 |
+
label="Num Steps",
|
| 732 |
+
value=500,
|
| 733 |
+
visible=False,
|
| 734 |
+
interactive=True,
|
| 735 |
+
elem_id="num_steps",
|
| 736 |
+
)
|
| 737 |
+
with gr.Row():
|
| 738 |
+
optimizer = gr.Dropdown(
|
| 739 |
+
label="Optimizer",
|
| 740 |
+
choices=["adamw_torch", "adamw_hf", "sgd", "adafactor", "adagrad"],
|
| 741 |
+
value="adamw_torch",
|
| 742 |
+
visible=False,
|
| 743 |
+
interactive=True,
|
| 744 |
+
elem_id="optimizer",
|
| 745 |
+
)
|
| 746 |
+
scheduler = gr.Dropdown(
|
| 747 |
+
label="Scheduler",
|
| 748 |
+
choices=["linear", "cosine"],
|
| 749 |
+
value="linear",
|
| 750 |
+
visible=False,
|
| 751 |
+
interactive=True,
|
| 752 |
+
elem_id="scheduler",
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
add_job_button = gr.Button(
|
| 756 |
+
value="Add Job",
|
| 757 |
+
visible=True,
|
| 758 |
+
interactive=True,
|
| 759 |
+
elem_id="add_job",
|
| 760 |
+
)
|
| 761 |
+
# clear_jobs_button = gr.Button(
|
| 762 |
+
# value="Clear Jobs",
|
| 763 |
+
# visible=True,
|
| 764 |
+
# interactive=True,
|
| 765 |
+
# elem_id="clear_jobs",
|
| 766 |
+
# )
|
| 767 |
+
gr.Markdown("<hr>")
|
| 768 |
+
estimated_costs_md = gr.Markdown(value="Estimated Costs: N/A", visible=True, interactive=False)
|
| 769 |
+
estimated_costs_num = gr.Number(value=0, visible=False, interactive=False)
|
| 770 |
+
create_project_button = gr.Button(
|
| 771 |
+
value="Create Project",
|
| 772 |
+
visible=True,
|
| 773 |
+
interactive=True,
|
| 774 |
+
elem_id="create_project",
|
| 775 |
+
)
|
| 776 |
+
with gr.Column():
|
| 777 |
+
param_choice = gr.Dropdown(
|
| 778 |
+
label="Param Choice",
|
| 779 |
+
choices=["AutoTrain"],
|
| 780 |
+
value="AutoTrain",
|
| 781 |
+
visible=True,
|
| 782 |
+
interactive=True,
|
| 783 |
+
)
|
| 784 |
+
training_params_txt = gr.Text(value="[]", visible=False, interactive=False)
|
| 785 |
+
training_params_md = gr.DataFrame(visible=False, interactive=False)
|
| 786 |
+
|
| 787 |
+
final_output = gr.Markdown(value="", visible=True, interactive=False)
|
| 788 |
+
hyperparameters = [
|
| 789 |
+
hub_model,
|
| 790 |
+
num_models,
|
| 791 |
+
source_language,
|
| 792 |
+
target_language,
|
| 793 |
+
learning_rate,
|
| 794 |
+
batch_size,
|
| 795 |
+
num_epochs,
|
| 796 |
+
gradient_accumulation_steps,
|
| 797 |
+
lora_r,
|
| 798 |
+
lora_alpha,
|
| 799 |
+
lora_dropout,
|
| 800 |
+
optimizer,
|
| 801 |
+
scheduler,
|
| 802 |
+
percentage_warmup_steps,
|
| 803 |
+
weight_decay,
|
| 804 |
+
db_num_steps,
|
| 805 |
+
image_size,
|
| 806 |
+
]
|
| 807 |
+
|
| 808 |
+
def _update_params(params_data):
    """Refresh hyperparameter widget visibility for the current selection.

    Gradio change-callback: ``params_data`` maps each input component to its
    current value. Looks up which hyperparameters are valid for the selected
    task / param-choice / model-choice combination and shows only those,
    re-enables the "Add Job" button, hides the jobs table, and clears the
    queued-jobs JSON back to an empty list.
    """
    _task = params_data[task_type]
    # Translate the UI task label into the internal task identifier.
    _task = APP_TASKS_MAPPING[_task]
    params = Params(
        task=_task,
        param_choice="autotrain" if params_data[param_choice] == "AutoTrain" else "manual",
        model_choice="autotrain" if params_data[model_choice] == "AutoTrain" else "hub_model",
    )
    params = params.get()
    # Collect the elem_ids of hyperparameters applicable to this configuration.
    visible_params = []
    for param in hyperparameters:
        if param.elem_id in params.keys():
            visible_params.append(param.elem_id)
    # One update per hyperparameter widget, in the same order as `hyperparameters`
    # (must match the `outputs=` list of the callers that wire this function).
    op = [h.update(visible=h.elem_id in visible_params) for h in hyperparameters]
    op.append(add_job_button.update(visible=True))
    op.append(training_params_md.update(visible=False))
    # Reset the serialized job list: changing task/params invalidates queued jobs.
    op.append(training_params_txt.update(value="[]"))
    return op
|
| 826 |
+
|
| 827 |
+
autotrain_backend.change(
|
| 828 |
+
_project_type_update,
|
| 829 |
+
inputs=[project_type, task_type, autotrain_backend],
|
| 830 |
+
outputs=[task_type, model_choice, param_choice, hub_model],
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
project_type.change(
|
| 834 |
+
_project_type_update,
|
| 835 |
+
inputs=[project_type, task_type, autotrain_backend],
|
| 836 |
+
outputs=[task_type, model_choice, param_choice, hub_model],
|
| 837 |
+
)
|
| 838 |
+
task_type.change(
|
| 839 |
+
_task_type_update,
|
| 840 |
+
inputs=[task_type, autotrain_backend],
|
| 841 |
+
outputs=[model_choice, param_choice, hub_model],
|
| 842 |
+
)
|
| 843 |
+
model_choice.change(
|
| 844 |
+
_update_param_choice,
|
| 845 |
+
inputs=[model_choice, autotrain_backend],
|
| 846 |
+
outputs=param_choice,
|
| 847 |
+
).then(
|
| 848 |
+
_update_hub_model_choices,
|
| 849 |
+
inputs=[task_type, model_choice],
|
| 850 |
+
outputs=hub_model,
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
param_choice.change(
|
| 854 |
+
_update_params,
|
| 855 |
+
inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
|
| 856 |
+
outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
|
| 857 |
+
)
|
| 858 |
+
task_type.change(
|
| 859 |
+
_update_params,
|
| 860 |
+
inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
|
| 861 |
+
outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
|
| 862 |
+
)
|
| 863 |
+
model_choice.change(
|
| 864 |
+
_update_params,
|
| 865 |
+
inputs=set([task_type, param_choice, model_choice] + hyperparameters + [add_job_button]),
|
| 866 |
+
outputs=hyperparameters + [add_job_button, training_params_md, training_params_txt],
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
def _add_job(params_data):
    """Append the currently selected hyperparameters as a training job.

    Gradio click-callback for the "Add Job" button. Reads the current value of
    every hyperparameter widget, appends (or, in AutoTrain mode, replaces) the
    job in the JSON-serialized job list, and returns updates for the jobs
    DataFrame and the hidden JSON textbox.

    Returns None (no component updates) when a hub model is required but not
    selected.
    """
    _task = params_data[task_type]
    _task = APP_TASKS_MAPPING[_task]
    _param_choice = "autotrain" if params_data[param_choice] == "AutoTrain" else "manual"
    _model_choice = "autotrain" if params_data[model_choice] == "AutoTrain" else "hub_model"
    if _model_choice == "hub_model" and params_data[hub_model] is None:
        logger.error("Hub model is None")
        return
    # Keep only the hyperparameters that are valid for this task/choice combo.
    _training_params = {}
    params = Params(task=_task, param_choice=_param_choice, model_choice=_model_choice)
    params = params.get()
    for _param in hyperparameters:
        if _param.elem_id in params.keys():
            _training_params[_param.elem_id] = params_data[_param]
    # The queued jobs are stored as JSON in a hidden textbox.
    _training_params_md = json.loads(params_data[training_params_txt])
    if _param_choice == "autotrain":
        # AutoTrain mode supports exactly one job: replace any existing entry.
        if len(_training_params_md) > 0:
            _training_params_md[0] = _training_params
            _training_params_md = _training_params_md[:1]
        else:
            _training_params_md.append(_training_params)
    else:
        # Manual mode: each click queues an additional job.
        _training_params_md.append(_training_params)
    params_df = pd.DataFrame(_training_params_md)
    # remove hub_model column
    if "hub_model" in params_df.columns:
        params_df = params_df.drop(columns=["hub_model"])
    return [
        gr.DataFrame.update(value=params_df, visible=True),
        gr.Textbox.update(value=json.dumps(_training_params_md), visible=False),
    ]
|
| 900 |
+
|
| 901 |
+
add_job_button.click(
|
| 902 |
+
_add_job,
|
| 903 |
+
inputs=set(
|
| 904 |
+
[task_type, param_choice, model_choice] + hyperparameters + [training_params_md, training_params_txt]
|
| 905 |
+
),
|
| 906 |
+
outputs=[training_params_md, training_params_txt],
|
| 907 |
+
)
|
| 908 |
+
col_map_components = [
|
| 909 |
+
col_map_text,
|
| 910 |
+
col_map_target,
|
| 911 |
+
concept_token,
|
| 912 |
+
]
|
| 913 |
+
training_data.change(
|
| 914 |
+
_update_col_map,
|
| 915 |
+
inputs=[training_data, task_type],
|
| 916 |
+
outputs=col_map_components,
|
| 917 |
+
)
|
| 918 |
+
task_type.change(
|
| 919 |
+
_update_col_map,
|
| 920 |
+
inputs=[training_data, task_type],
|
| 921 |
+
outputs=col_map_components,
|
| 922 |
+
)
|
| 923 |
+
estimate_costs_inputs = [
|
| 924 |
+
training_data,
|
| 925 |
+
validation_data,
|
| 926 |
+
task_type,
|
| 927 |
+
user_token,
|
| 928 |
+
autotrain_username,
|
| 929 |
+
training_params_txt,
|
| 930 |
+
autotrain_backend,
|
| 931 |
+
]
|
| 932 |
+
estimate_costs_outputs = [estimated_costs_md, estimated_costs_num]
|
| 933 |
+
training_data.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
|
| 934 |
+
validation_data.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
|
| 935 |
+
training_params_txt.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
|
| 936 |
+
task_type.change(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
|
| 937 |
+
add_job_button.click(_estimate_costs, inputs=estimate_costs_inputs, outputs=estimate_costs_outputs)
|
| 938 |
+
|
| 939 |
+
create_project_button.click(disable_create_project_button, None, create_project_button).then(
|
| 940 |
+
_create_project,
|
| 941 |
+
inputs=[
|
| 942 |
+
autotrain_username,
|
| 943 |
+
valid_can_pay,
|
| 944 |
+
project_name,
|
| 945 |
+
user_token,
|
| 946 |
+
task_type,
|
| 947 |
+
training_data,
|
| 948 |
+
validation_data,
|
| 949 |
+
col_map_text,
|
| 950 |
+
col_map_target,
|
| 951 |
+
concept_token,
|
| 952 |
+
training_params_txt,
|
| 953 |
+
hub_model,
|
| 954 |
+
estimated_costs_num,
|
| 955 |
+
autotrain_backend,
|
| 956 |
+
],
|
| 957 |
+
outputs=final_output,
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
demo.load(
|
| 961 |
+
_update_project_name,
|
| 962 |
+
outputs=[project_name, create_project_button],
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
return demo
|
autotrain-advanced/src/autotrain/cli/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseAutoTrainCommand(ABC):
    """Abstract interface implemented by every ``autotrain`` CLI sub-command.

    Concrete commands register their argparse sub-parser via
    :meth:`register_subcommand` and perform the actual work in :meth:`run`.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser and its arguments to *parser*."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command using the already-parsed arguments."""
        raise NotImplementedError()
|
autotrain-advanced/src/autotrain/cli/accelerated_autotrain.py
ADDED
|
File without changes
|
autotrain-advanced/src/autotrain/cli/autotrain.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from .. import __version__
|
| 4 |
+
from .run_app import RunAutoTrainAppCommand
|
| 5 |
+
from .run_dreambooth import RunAutoTrainDreamboothCommand
|
| 6 |
+
from .run_llm import RunAutoTrainLLMCommand
|
| 7 |
+
from .run_setup import RunSetupCommand
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """Entry point of the ``autotrain`` command line interface.

    Builds the top-level argument parser, registers every sub-command,
    dispatches to the selected command's factory and runs it. Exits with
    status 0 after printing the version for ``--version``, and with status 1
    (after printing help) when no sub-command was given.
    """
    parser = argparse.ArgumentParser(
        "AutoTrain advanced CLI",
        usage="autotrain <command> [<args>]",
        epilog="For more information about a command, run: `autotrain <command> --help`",
    )
    parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
    commands_parser = parser.add_subparsers(help="commands")

    # Register commands (order preserved: it determines help-text ordering).
    for command_cls in (
        RunAutoTrainAppCommand,
        RunAutoTrainLLMCommand,
        RunSetupCommand,
        RunAutoTrainDreamboothCommand,
    ):
        command_cls.register_subcommand(commands_parser)

    args = parser.parse_args()

    if args.version:
        print(__version__)
        exit(0)

    # Each sub-parser sets `func` via set_defaults; its absence means the
    # user supplied no sub-command at all.
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    command = args.func(args)
    command.run()


if __name__ == "__main__":
    main()
|
autotrain-advanced/src/autotrain/cli/run_app.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser
|
| 2 |
+
|
| 3 |
+
from . import BaseAutoTrainCommand
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def run_app_command_factory(args):
    """argparse ``func`` callback: build the app command from parsed args."""
    port, host, task = args.port, args.host, args.task
    return RunAutoTrainAppCommand(port, host, task)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RunAutoTrainAppCommand(BaseAutoTrainCommand):
    """CLI command that launches the AutoTrain web UI (``autotrain app``).

    ``--task dreambooth`` launches the dedicated DreamBooth app; any other
    value (or no ``--task``) launches the general AutoTrain app.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``app`` sub-parser and its arguments on *parser*."""
        run_app_parser = parser.add_parser(
            "app",
            description="✨ Run AutoTrain app",
        )
        run_app_parser.add_argument(
            "--port",
            type=int,
            default=7860,
            help="Port to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--host",
            type=str,
            default="127.0.0.1",
            help="Host to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--task",
            type=str,
            required=False,
            help="Task to run",
        )
        run_app_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, port, host, task):
        self.port = port  # TCP port the gradio server should listen on
        self.host = host  # interface/hostname the gradio server should bind to
        self.task = task  # optional task name; "dreambooth" selects the DreamBooth UI

    def run(self):
        """Build the selected Gradio app and serve it on the configured host/port."""
        # Import lazily so only the selected app's dependencies are loaded.
        if self.task == "dreambooth":
            from ..dreambooth_app import main
        else:
            from ..app import main

        demo = main()
        # Bug fix: --host/--port were parsed and stored but never forwarded to
        # gradio, so the server always started on gradio's defaults. Pass them
        # through via launch()'s server_name/server_port parameters.
        demo.queue(concurrency_count=10).launch(server_name=self.host, server_port=self.port)
|
autotrain-advanced/src/autotrain/cli/run_dreambooth.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
|
| 5 |
+
from loguru import logger
|
| 6 |
+
|
| 7 |
+
from autotrain.cli import BaseAutoTrainCommand
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# DreamBooth training extras are optional and installed by `autotrain setup`.
# Keep this module importable without them: warn instead of failing hard, so
# the rest of the CLI keeps working. If the import fails, the names below
# (train_dreambooth, DreamBoothTrainingParams, VALID_IMAGE_EXTENSIONS,
# XL_MODELS) stay undefined and using the dreambooth command will error.
try:
    from autotrain.trainers.dreambooth import train as train_dreambooth
    from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
    from autotrain.trainers.dreambooth.utils import VALID_IMAGE_EXTENSIONS, XL_MODELS
except ImportError:
    logger.warning(
        "❌ Some DreamBooth components are missing! Please run `autotrain setup` to install it. Ignore this warning if you are not using DreamBooth or running `autotrain setup` already."
    )
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def count_images(directory):
    """Return how many files in *directory* carry a valid image extension."""
    total = 0
    for extension in VALID_IMAGE_EXTENSIONS:
        # Match only top-level files ending in this extension (non-recursive).
        total += len(glob.glob(os.path.join(directory, "*" + extension)))
    return total
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def run_dreambooth_command_factory(args):
    # argparse ``func`` callback: wrap the parsed namespace in the command object.
    return RunAutoTrainDreamboothCommand(args)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RunAutoTrainDreamboothCommand(BaseAutoTrainCommand):
    """CLI command that runs DreamBooth training (``autotrain dreambooth``)."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``dreambooth`` sub-parser and all of its arguments.

        Arguments are declared as a data table and turned into
        ``add_argument`` calls; entries with an ``action`` key are boolean
        flags and take no ``type``.
        """
        arg_list = [
            {"arg": "--model", "help": "Model to use for training", "required": True, "type": str},
            {"arg": "--revision", "help": "Model revision to use for training", "required": False, "type": str},
            {"arg": "--tokenizer", "help": "Tokenizer to use for training", "required": False, "type": str},
            {"arg": "--image-path", "help": "Path to the images", "required": True, "type": str},
            {"arg": "--class-image-path", "help": "Path to the class images", "required": False, "type": str},
            {"arg": "--prompt", "help": "Instance prompt", "required": True, "type": str},
            {"arg": "--class-prompt", "help": "Class prompt", "required": False, "type": str},
            {"arg": "--num-class-images", "help": "Number of class images", "required": False, "default": 100, "type": int},
            {"arg": "--class-labels-conditioning", "help": "Class labels conditioning", "required": False, "type": str},
            {"arg": "--prior-preservation", "help": "With prior preservation", "required": False, "action": "store_true"},
            {"arg": "--prior-loss-weight", "help": "Prior loss weight", "required": False, "default": 1.0, "type": float},
            {"arg": "--output", "help": "Output directory", "required": True, "type": str},
            {"arg": "--seed", "help": "Seed", "required": False, "default": 42, "type": int},
            {"arg": "--resolution", "help": "Resolution", "required": True, "type": int},
            {"arg": "--center-crop", "help": "Center crop", "required": False, "action": "store_true"},
            {"arg": "--train-text-encoder", "help": "Train text encoder", "required": False, "action": "store_true"},
            {"arg": "--batch-size", "help": "Train batch size", "required": False, "default": 4, "type": int},
            {"arg": "--sample-batch-size", "help": "Sample batch size", "required": False, "default": 4, "type": int},
            {"arg": "--epochs", "help": "Number of training epochs", "required": False, "default": 1, "type": int},
            {"arg": "--num-steps", "help": "Max train steps", "required": False, "type": int},
            {"arg": "--checkpointing-steps", "help": "Checkpointing steps", "required": False, "default": 100000, "type": int},
            {"arg": "--resume-from-checkpoint", "help": "Resume from checkpoint", "required": False, "type": str},
            {"arg": "--gradient-accumulation", "help": "Gradient accumulation steps", "required": False, "default": 1, "type": int},
            {"arg": "--gradient-checkpointing", "help": "Gradient checkpointing", "required": False, "action": "store_true"},
            {"arg": "--lr", "help": "Learning rate", "required": False, "default": 5e-4, "type": float},
            {"arg": "--scale-lr", "help": "Scale learning rate", "required": False, "action": "store_true"},
            {"arg": "--scheduler", "help": "Learning rate scheduler", "required": False, "default": "constant"},
            {"arg": "--warmup-steps", "help": "Learning rate warmup steps", "required": False, "default": 0, "type": int},
            {"arg": "--num-cycles", "help": "Learning rate num cycles", "required": False, "default": 1, "type": int},
            {"arg": "--lr-power", "help": "Learning rate power", "required": False, "default": 1.0, "type": float},
            {"arg": "--dataloader-num-workers", "help": "Dataloader num workers", "required": False, "default": 0, "type": int},
            {"arg": "--use-8bit-adam", "help": "Use 8bit adam", "required": False, "action": "store_true"},
            {"arg": "--adam-beta1", "help": "Adam beta 1", "required": False, "default": 0.9, "type": float},
            {"arg": "--adam-beta2", "help": "Adam beta 2", "required": False, "default": 0.999, "type": float},
            {"arg": "--adam-weight-decay", "help": "Adam weight decay", "required": False, "default": 1e-2, "type": float},
            {"arg": "--adam-epsilon", "help": "Adam epsilon", "required": False, "default": 1e-8, "type": float},
            {"arg": "--max-grad-norm", "help": "Max grad norm", "required": False, "default": 1.0, "type": float},
            {"arg": "--allow-tf32", "help": "Allow TF32", "required": False, "action": "store_true"},
            {"arg": "--prior-generation-precision", "help": "Prior generation precision", "required": False, "type": str},
            {"arg": "--local-rank", "help": "Local rank", "required": False, "default": -1, "type": int},
            {"arg": "--xformers", "help": "Enable xformers memory efficient attention", "required": False, "action": "store_true"},
            {"arg": "--pre-compute-text-embeddings", "help": "Pre compute text embeddings", "required": False, "action": "store_true"},
            {"arg": "--tokenizer-max-length", "help": "Tokenizer max length", "required": False, "type": int},
            {"arg": "--text-encoder-use-attention-mask", "help": "Text encoder use attention mask", "required": False, "action": "store_true"},
            {"arg": "--rank", "help": "Rank", "required": False, "default": 4, "type": int},
            {"arg": "--xl", "help": "XL", "required": False, "action": "store_true"},
            {"arg": "--fp16", "help": "FP16", "required": False, "action": "store_true"},
            {"arg": "--bf16", "help": "BF16", "required": False, "action": "store_true"},
            {"arg": "--hub-token", "help": "Hub token", "required": False, "type": str},
            {"arg": "--hub-model-id", "help": "Hub model id", "required": False, "type": str},
            {"arg": "--push-to-hub", "help": "Push to hub", "required": False, "action": "store_true"},
            {"arg": "--validation-prompt", "help": "Validation prompt", "required": False, "type": str},
            {"arg": "--num-validation-images", "help": "Number of validation images", "required": False, "default": 4, "type": int},
            {"arg": "--validation-epochs", "help": "Validation epochs", "required": False, "default": 50, "type": int},
            {"arg": "--checkpoints-total-limit", "help": "Checkpoints total limit", "required": False, "type": int},
            {"arg": "--validation-images", "help": "Validation images", "required": False, "type": str},
            {"arg": "--logging", "help": "Logging using tensorboard", "required": False, "action": "store_true"},
        ]

        run_dreambooth_parser = parser.add_parser("dreambooth", description="✨ Run AutoTrain DreamBooth Training")
        for arg in arg_list:
            kwargs = {
                "help": arg["help"],
                "required": arg.get("required", False),
                "default": arg.get("default"),
            }
            if "action" in arg:
                # Boolean flag: argparse rejects ``type`` for store_true actions.
                kwargs["action"] = arg.get("action")
            else:
                kwargs["type"] = arg.get("type")
            run_dreambooth_parser.add_argument(arg["arg"], **kwargs)
        run_dreambooth_parser.set_defaults(func=run_dreambooth_command_factory)

    def __init__(self, args):
        self.args = args
        logger.info(self.args)

        # Flags are registered with default=None above; normalize them to
        # False so downstream code can rely on real booleans.
        store_true_arg_names = [
            "center_crop",
            "train_text_encoder",
            "gradient_checkpointing",
            "scale_lr",
            "use_8bit_adam",
            "allow_tf32",
            "xformers",
            "pre_compute_text_embeddings",
            "text_encoder_use_attention_mask",
            "xl",
            "fp16",
            "bf16",
            "push_to_hub",
            "logging",
            "prior_preservation",
        ]
        for arg_name in store_true_arg_names:
            if getattr(self.args, arg_name) is None:
                setattr(self.args, arg_name, False)

        # Mixed precision modes are mutually exclusive.
        if self.args.fp16 and self.args.bf16:
            raise ValueError("❌ Please choose either FP16 or BF16")

        # check if self.args.image_path is a directory with images
        if not os.path.isdir(self.args.image_path):
            raise ValueError("❌ Please specify a valid image directory")

        # count the number of images in the directory. valid images are .jpg, .jpeg, .png
        if count_images(self.args.image_path) == 0:
            raise ValueError("❌ Please specify a valid image directory")

        # Pushing to the Hub requires a destination repo id.
        if self.args.push_to_hub:
            if self.args.hub_model_id is None:
                raise ValueError("❌ Please specify a hub model id")

        # Known SDXL base models force the XL training path.
        if self.args.model in XL_MODELS:
            self.args.xl = True

    def run(self):
        """Convert parsed args into training params and start training."""
        logger.info("Running DreamBooth Training")
        params = DreamBoothTrainingParams(**vars(self.args))
        train_dreambooth(params)
|
autotrain-advanced/src/autotrain/cli/run_llm.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser
|
| 2 |
+
|
| 3 |
+
from loguru import logger
|
| 4 |
+
|
| 5 |
+
from autotrain.infer.text_generation import TextGenerationInference
|
| 6 |
+
|
| 7 |
+
from ..trainers.clm import train as train_llm
|
| 8 |
+
from ..trainers.utils import LLMTrainingParams
|
| 9 |
+
from . import BaseAutoTrainCommand
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def run_llm_command_factory(args):
    """Adapt the parsed ``llm`` sub-command namespace into a command object.

    ``RunAutoTrainLLMCommand`` takes its values positionally; the tuple below
    spells out that exact order so the mapping is auditable at a glance.
    """
    _ARG_ORDER = (
        "train", "deploy", "inference", "data_path", "train_split",
        "valid_split", "text_column", "model", "learning_rate",
        "num_train_epochs", "train_batch_size", "eval_batch_size",
        "warmup_ratio", "gradient_accumulation_steps", "optimizer",
        "scheduler", "weight_decay", "max_grad_norm", "seed",
        "add_eos_token", "block_size", "use_peft", "lora_r", "lora_alpha",
        "lora_dropout", "training_type", "train_on_inputs", "logging_steps",
        "project_name", "evaluation_strategy", "save_total_limit",
        "save_strategy", "auto_find_batch_size", "fp16", "push_to_hub",
        "use_int8", "model_max_length", "repo_id", "use_int4", "trainer",
        "target_modules",
    )
    return RunAutoTrainLLMCommand(*(getattr(args, name) for name in _ARG_ORDER))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class RunAutoTrainLLMCommand(BaseAutoTrainCommand):
    """CLI command for AutoTrain LLM training / inference / deployment.

    The external interface is unchanged: ``register_subcommand`` registers
    the ``llm`` sub-parser with the same flags, defaults and help strings;
    the constructor takes the same 41 positional parameters; ``run`` trains
    when ``--train`` was passed. Internally the hundreds of near-identical
    ``add_argument`` calls and ``self.x = x`` assignments are replaced with
    data-driven equivalents so a new option only needs one table entry.
    """

    # (flag, add_argument kwargs) in the exact original registration order,
    # so --help output is unchanged.
    _CLI_OPTIONS = [
        ("--train", dict(help="Train the model", required=False, action="store_true")),
        ("--deploy", dict(help="Deploy the model", required=False, action="store_true")),
        ("--inference", dict(help="Run inference", required=False, action="store_true")),
        ("--data_path", dict(help="Train dataset to use", required=False, type=str)),
        ("--train_split", dict(help="Test dataset split to use", required=False, type=str, default="train")),
        ("--valid_split", dict(help="Validation dataset split to use", required=False, type=str, default=None)),
        ("--text_column", dict(help="Text column to use", required=False, type=str, default="text")),
        ("--model", dict(help="Model to use", required=False, type=str)),
        ("--learning_rate", dict(help="Learning rate to use", required=False, type=float, default=3e-5)),
        ("--num_train_epochs", dict(help="Number of training epochs to use", required=False, type=int, default=1)),
        ("--train_batch_size", dict(help="Training batch size to use", required=False, type=int, default=2)),
        ("--eval_batch_size", dict(help="Evaluation batch size to use", required=False, type=int, default=4)),
        ("--warmup_ratio", dict(help="Warmup proportion to use", required=False, type=float, default=0.1)),
        ("--gradient_accumulation_steps", dict(help="Gradient accumulation steps to use", required=False, type=int, default=1)),
        ("--optimizer", dict(help="Optimizer to use", required=False, type=str, default="adamw_torch")),
        ("--scheduler", dict(help="Scheduler to use", required=False, type=str, default="linear")),
        ("--weight_decay", dict(help="Weight decay to use", required=False, type=float, default=0.0)),
        ("--max_grad_norm", dict(help="Max gradient norm to use", required=False, type=float, default=1.0)),
        ("--seed", dict(help="Seed to use", required=False, type=int, default=42)),
        ("--add_eos_token", dict(help="Add EOS token to use", required=False, action="store_true")),
        ("--block_size", dict(help="Block size to use", required=False, type=int, default=-1)),
        ("--use_peft", dict(help="Use PEFT to use", required=False, action="store_true")),
        ("--lora_r", dict(help="Lora r to use", required=False, type=int, default=16)),
        ("--lora_alpha", dict(help="Lora alpha to use", required=False, type=int, default=32)),
        ("--lora_dropout", dict(help="Lora dropout to use", required=False, type=float, default=0.05)),
        ("--training_type", dict(help="Training type to use", required=False, type=str, default="generic")),
        ("--train_on_inputs", dict(help="Train on inputs to use", required=False, action="store_true")),
        ("--logging_steps", dict(help="Logging steps to use", required=False, type=int, default=-1)),
        ("--project_name", dict(help="Output directory", required=False, type=str)),
        ("--evaluation_strategy", dict(help="Evaluation strategy to use", required=False, type=str, default="epoch")),
        ("--save_total_limit", dict(help="Save total limit to use", required=False, type=int, default=1)),
        ("--save_strategy", dict(help="Save strategy to use", required=False, type=str, default="epoch")),
        ("--auto_find_batch_size", dict(help="Auto find batch size True/False", required=False, action="store_true")),
        ("--fp16", dict(help="FP16 True/False", required=False, action="store_true")),
        ("--push_to_hub", dict(help="Push to hub True/False", required=False, action="store_true")),
        ("--use_int8", dict(help="Use int8 True/False", required=False, action="store_true")),
        ("--model_max_length", dict(help="Model max length to use", required=False, type=int, default=1024)),
        ("--repo_id", dict(help="Repo id for hugging face hub", required=False, type=str)),
        ("--use_int4", dict(help="Use int4 True/False", required=False, action="store_true")),
        ("--trainer", dict(help="Trainer type to use", required=False, type=str, default="default")),
        ("--target_modules", dict(help="Target modules to use", required=False, type=str, default=None)),
    ]

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``llm`` sub-command and every one of its options."""
        run_llm_parser = parser.add_parser(
            "llm",
            description="✨ Run AutoTrain LLM training/inference/deployment",
        )
        for flag, kwargs in RunAutoTrainLLMCommand._CLI_OPTIONS:
            run_llm_parser.add_argument(flag, **kwargs)
        run_llm_parser.set_defaults(func=run_llm_command_factory)

    def __init__(
        self,
        train,
        deploy,
        inference,
        data_path,
        train_split,
        valid_split,
        text_column,
        model,
        learning_rate,
        num_train_epochs,
        train_batch_size,
        eval_batch_size,
        warmup_ratio,
        gradient_accumulation_steps,
        optimizer,
        scheduler,
        weight_decay,
        max_grad_norm,
        seed,
        add_eos_token,
        block_size,
        use_peft,
        lora_r,
        lora_alpha,
        lora_dropout,
        training_type,
        train_on_inputs,
        logging_steps,
        project_name,
        evaluation_strategy,
        save_total_limit,
        save_strategy,
        auto_find_batch_size,
        fp16,
        push_to_hub,
        use_int8,
        model_max_length,
        repo_id,
        use_int4,
        trainer,
        target_modules,
    ):
        """Store every CLI value as an attribute, validate, and optionally
        start the interactive inference loop.

        Raises:
            ValueError: when ``--train`` is requested without a project name,
                data path or model, or ``--push_to_hub`` without ``--repo_id``.
        """
        # Every parameter is stored under its own name; capturing locals()
        # avoids 41 hand-written `self.x = x` lines.
        params = dict(locals())
        params.pop("self")
        for name, value in params.items():
            setattr(self, name, value)

        if self.train:
            if self.project_name is None:
                raise ValueError("Project name must be specified")
            if self.data_path is None:
                raise ValueError("Data path must be specified")
            if self.model is None:
                raise ValueError("Model must be specified")
            if self.push_to_hub and self.repo_id is None:
                raise ValueError("Repo id must be specified for push to hub")

        # NOTE(review): as in the original, the interactive inference REPL
        # runs here in the constructor rather than in run(); preserved so the
        # CLI behaves identically. Type "exit()" to leave the loop.
        if self.inference:
            tgi = TextGenerationInference(self.project_name, use_int4=self.use_int4, use_int8=self.use_int8)
            while True:
                prompt = input("User: ")
                if prompt == "exit()":
                    break
                print(f"Bot: {tgi.chat(prompt)}")

    # Attributes forwarded verbatim to LLMTrainingParams; `model` is the one
    # exception (it maps to `model_name`) and is handled explicitly in run().
    _TRAINING_FIELDS = (
        "data_path", "train_split", "valid_split", "text_column",
        "learning_rate", "num_train_epochs", "train_batch_size",
        "eval_batch_size", "warmup_ratio", "gradient_accumulation_steps",
        "optimizer", "scheduler", "weight_decay", "max_grad_norm", "seed",
        "add_eos_token", "block_size", "use_peft", "lora_r", "lora_alpha",
        "lora_dropout", "training_type", "train_on_inputs", "logging_steps",
        "project_name", "evaluation_strategy", "save_total_limit",
        "save_strategy", "auto_find_batch_size", "fp16", "push_to_hub",
        "use_int8", "model_max_length", "repo_id", "use_int4", "trainer",
        "target_modules",
    )

    def run(self):
        """Kick off LLM training when ``--train`` was requested."""
        logger.info("Running LLM")
        logger.info(f"Train: {self.train}")
        if not self.train:
            return
        params = LLMTrainingParams(
            model_name=self.model,
            **{name: getattr(self, name) for name in self._TRAINING_FIELDS},
        )
        train_llm(params)
|
autotrain-advanced/src/autotrain/cli/run_setup.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
|
| 4 |
+
from loguru import logger
|
| 5 |
+
|
| 6 |
+
from . import BaseAutoTrainCommand
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_app_command_factory(args):
    """Adapt the parsed ``setup`` sub-command namespace into a command object."""
    return RunSetupCommand(update_torch=args.update_torch)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RunSetupCommand(BaseAutoTrainCommand):
    """CLI command that upgrades core dependencies to their git ``main`` heads.

    Reinstalls transformers, peft, diffusers and trl from GitHub, and — when
    ``--update-torch`` is passed — PyTorch from the cu118 wheel index. The
    original repeated the same Popen/log/communicate/log boilerplate five
    times; it is factored into ``_run_quiet`` with a spec table, keeping the
    exact command strings, log messages and ordering.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``setup`` sub-command on the top-level parser."""
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool):
        # Whether to also reinstall PyTorch (CUDA 11.8 wheels).
        self.update_torch = update_torch

    @staticmethod
    def _run_quiet(cmd: str, start_msg: str, done_msg: str) -> None:
        """Run *cmd* in a shell, swallowing its output, with log lines around it.

        NOTE: output (and any failure) is discarded, as in the original —
        the exit status is not checked.
        """
        # shell=True is safe here: every command string is a fixed literal.
        pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logger.info(start_msg)
        _, _ = pipe.communicate()
        logger.info(done_msg)

    def run(self):
        """Reinstall the tracked libraries from source, then optionally torch."""
        upgrades = [
            ("transformers", "git+https://github.com/huggingface/transformers.git"),
            ("peft", "git+https://github.com/huggingface/peft.git"),
            ("diffusers", "git+https://github.com/huggingface/diffusers.git"),
            ("trl", "git+https://github.com/lvwerra/trl.git"),
        ]
        for package, source in upgrades:
            self._run_quiet(
                f"pip uninstall -y {package} && pip install {source}",
                f"Installing latest {package}@main",
                f"Successfully installed latest {package}",
            )

        if self.update_torch:
            self._run_quiet(
                "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
                "Installing latest PyTorch",
                "Successfully installed latest PyTorch",
            )
|
autotrain-advanced/src/autotrain/config.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

from loguru import logger


# Base URL of the AutoTrain backend API; override with the
# AUTOTRAIN_BACKEND_API environment variable (e.g. for staging).
AUTOTRAIN_BACKEND_API = os.getenv("AUTOTRAIN_BACKEND_API", "https://api.autotrain.huggingface.co")

# Base URL of the Hugging Face Hub; override with the HF_API env var.
HF_API = os.getenv("HF_API", "https://huggingface.co")


# Route all loguru output to stderr with a compact "> LEVEL   message" format.
logger.configure(handlers=[dict(sink=sys.stderr, format="> <level>{level:<7} {message}</level>")])
|
autotrain-advanced/src/autotrain/dataset.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import zipfile
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
from autotrain.preprocessor.dreambooth import DreamboothPreprocessor
|
| 11 |
+
from autotrain.preprocessor.tabular import (
|
| 12 |
+
TabularBinaryClassificationPreprocessor,
|
| 13 |
+
TabularMultiClassClassificationPreprocessor,
|
| 14 |
+
TabularSingleColumnRegressionPreprocessor,
|
| 15 |
+
)
|
| 16 |
+
from autotrain.preprocessor.text import (
|
| 17 |
+
LLMPreprocessor,
|
| 18 |
+
TextBinaryClassificationPreprocessor,
|
| 19 |
+
TextMultiClassClassificationPreprocessor,
|
| 20 |
+
TextSingleColumnRegressionPreprocessor,
|
| 21 |
+
)
|
| 22 |
+
from autotrain.preprocessor.vision import ImageClassificationPreprocessor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def remove_non_image_files(folder):
    """Recursively delete every file under *folder* that is not an image.

    A file is kept only when its extension, lower-cased, is .jpg, .jpeg or
    .png. Directories themselves are left in place; each removal is printed.

    Fix: ``os.walk`` already descends into every subdirectory, so the single
    pass below suffices. The original additionally recursed into each entry
    of ``dirs`` by hand, re-walking every subtree (quadratic duplicated
    work) without removing anything extra.
    """
    # Comparison is done on the lower-cased extension, so listing only the
    # lower-case forms covers .JPG/.JPEG/.PNG as well.
    allowed_extensions = {".jpg", ".jpeg", ".png"}

    for root, _dirs, files in os.walk(folder):
        for file_name in files:
            extension = os.path.splitext(file_name)[1]
            if extension.lower() not in allowed_extensions:
                file_path = os.path.join(root, file_name)
                os.remove(file_path)
                print(f"Removed file: {file_path}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class AutoTrainDreamboothDataset:
    """In-memory description of a DreamBooth fine-tuning dataset.

    Holds the uploaded concept images plus the project/user/token metadata
    needed to hand everything off to ``DreamboothPreprocessor``. The task
    name is fixed to ``"dreambooth"`` in ``__post_init__``.
    """

    concept_images: List[Any]
    concept_name: str
    token: str
    project_name: str
    username: str

    def __post_init__(self):
        # Task is constant for this dataset type; log the summary on creation.
        self.task = "dreambooth"
        logger.info(self.__str__())

    def __str__(self) -> str:
        return f"Dataset: {self.project_name} ({self.task})\n"

    @property
    def num_samples(self):
        """Number of concept images supplied."""
        return len(self.concept_images)

    def prepare(self):
        """Delegate dataset preparation to the DreamBooth preprocessor."""
        preprocessor = DreamboothPreprocessor(
            concept_images=self.concept_images,
            concept_name=self.concept_name,
            token=self.token,
            project_name=self.project_name,
            username=self.username,
        )
        preprocessor.prepare()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
class AutoTrainImageClassificationDataset:
    """Image-classification dataset built from zipped image archives.

    ``train_data`` (and optionally ``valid_data``) are paths to zip files;
    ``prepare`` extracts each archive into a fresh directory under the
    Hugging Face cache, prunes non-image files, and hands the directories to
    ``ImageClassificationPreprocessor``.
    """

    # Path to the training zip archive.
    train_data: str
    # Hugging Face auth token forwarded to the preprocessor.
    token: str
    project_name: str
    username: str
    # Optional path to a validation zip archive.
    valid_data: Optional[str] = None
    # Fraction held out for validation when no explicit archive is given.
    percent_valid: Optional[float] = None

    def __str__(self) -> str:
        # Relies on self.task, which __post_init__ sets before logging this.
        info = f"Dataset: {self.project_name} ({self.task})\n"
        info += f"Train data: {self.train_data}\n"
        info += f"Valid data: {self.valid_data}\n"
        return info

    def __post_init__(self):
        self.task = "image_multi_class_classification"
        # Exactly one of valid_data / percent_valid may be supplied; default
        # to a 20% split when neither is given.
        if not self.valid_data and self.percent_valid is None:
            self.percent_valid = 0.2
        elif self.valid_data and self.percent_valid is not None:
            raise ValueError("You can only specify one of valid_data or percent_valid")
        elif self.valid_data:
            self.percent_valid = 0.0
        logger.info(self.__str__())

        self.num_files = self._count_files()

    @property
    def num_samples(self):
        # Total zip entry count; zip entries can include directories, so this
        # is an upper bound on the number of actual images.
        return self.num_files

    def _count_files(self):
        """Count entries across the train (and optional valid) archives."""
        num_files = 0
        zip_ref = zipfile.ZipFile(self.train_data, "r")
        for _ in zip_ref.namelist():
            num_files += 1
        if self.valid_data:
            zip_ref = zipfile.ZipFile(self.valid_data, "r")
            for _ in zip_ref.namelist():
                num_files += 1
        return num_files

    def prepare(self):
        """Extract archives, prune non-image files, and run preprocessing."""
        cache_dir = os.environ.get("HF_HOME")
        if not cache_dir:
            cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

        # Each extraction goes into a unique throwaway directory in the cache.
        random_uuid = uuid.uuid4()
        train_dir = os.path.join(cache_dir, "autotrain", str(random_uuid))
        os.makedirs(train_dir, exist_ok=True)
        zip_ref = zipfile.ZipFile(self.train_data, "r")
        zip_ref.extractall(train_dir)
        # remove the __MACOSX directory
        macosx_dir = os.path.join(train_dir, "__MACOSX")
        if os.path.exists(macosx_dir):
            # NOTE(review): shell string built from a path — fine for the
            # default cache location, but would break on paths containing
            # spaces; confirm or switch to shutil.rmtree.
            os.system(f"rm -rf {macosx_dir}")
        remove_non_image_files(train_dir)

        valid_dir = None
        if self.valid_data:
            random_uuid = uuid.uuid4()
            valid_dir = os.path.join(cache_dir, "autotrain", str(random_uuid))
            os.makedirs(valid_dir, exist_ok=True)
            zip_ref = zipfile.ZipFile(self.valid_data, "r")
            zip_ref.extractall(valid_dir)
            # remove the __MACOSX directory
            macosx_dir = os.path.join(valid_dir, "__MACOSX")
            if os.path.exists(macosx_dir):
                os.system(f"rm -rf {macosx_dir}")
            remove_non_image_files(valid_dir)

        preprocessor = ImageClassificationPreprocessor(
            train_data=train_dir,
            valid_data=valid_dir,
            token=self.token,
            project_name=self.project_name,
            username=self.username,
        )
        preprocessor.prepare()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@dataclass
|
| 159 |
+
class AutoTrainDataset:
|
| 160 |
+
train_data: List[str]
|
| 161 |
+
task: str
|
| 162 |
+
token: str
|
| 163 |
+
project_name: str
|
| 164 |
+
username: str
|
| 165 |
+
column_mapping: Optional[Dict[str, str]] = None
|
| 166 |
+
valid_data: Optional[List[str]] = None
|
| 167 |
+
percent_valid: Optional[float] = None
|
| 168 |
+
|
| 169 |
+
def __str__(self) -> str:
|
| 170 |
+
info = f"Dataset: {self.project_name} ({self.task})\n"
|
| 171 |
+
info += f"Train data: {self.train_data}\n"
|
| 172 |
+
info += f"Valid data: {self.valid_data}\n"
|
| 173 |
+
info += f"Column mapping: {self.column_mapping}\n"
|
| 174 |
+
return info
|
| 175 |
+
|
| 176 |
+
def __post_init__(self):
|
| 177 |
+
if not self.valid_data and self.percent_valid is None:
|
| 178 |
+
self.percent_valid = 0.2
|
| 179 |
+
elif self.valid_data and self.percent_valid is not None:
|
| 180 |
+
raise ValueError("You can only specify one of valid_data or percent_valid")
|
| 181 |
+
elif self.valid_data:
|
| 182 |
+
self.percent_valid = 0.0
|
| 183 |
+
|
| 184 |
+
self.train_df, self.valid_df = self._preprocess_data()
|
| 185 |
+
logger.info(self.__str__())
|
| 186 |
+
|
| 187 |
+
def _preprocess_data(self):
|
| 188 |
+
train_df = []
|
| 189 |
+
for file in self.train_data:
|
| 190 |
+
if isinstance(file, pd.DataFrame):
|
| 191 |
+
train_df.append(file)
|
| 192 |
+
else:
|
| 193 |
+
train_df.append(pd.read_csv(file))
|
| 194 |
+
if len(train_df) > 1:
|
| 195 |
+
train_df = pd.concat(train_df)
|
| 196 |
+
else:
|
| 197 |
+
train_df = train_df[0]
|
| 198 |
+
|
| 199 |
+
valid_df = None
|
| 200 |
+
if len(self.valid_data) > 0:
|
| 201 |
+
valid_df = []
|
| 202 |
+
for file in self.valid_data:
|
| 203 |
+
if isinstance(file, pd.DataFrame):
|
| 204 |
+
valid_df.append(file)
|
| 205 |
+
else:
|
| 206 |
+
valid_df.append(pd.read_csv(file))
|
| 207 |
+
if len(valid_df) > 1:
|
| 208 |
+
valid_df = pd.concat(valid_df)
|
| 209 |
+
else:
|
| 210 |
+
valid_df = valid_df[0]
|
| 211 |
+
return train_df, valid_df
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
def num_samples(self):
|
| 215 |
+
return len(self.train_df) + len(self.valid_df) if self.valid_df is not None else len(self.train_df)
|
| 216 |
+
|
| 217 |
+
def prepare(self):
    """Instantiate the task-specific preprocessor and run its ``prepare()``.

    Dispatches on ``self.task``. Text and tabular tasks differ only in the
    preprocessor class, so they are table-driven; ``lm_training`` keeps its
    bespoke column handling.

    Raises:
        ValueError: if ``self.task`` is not a supported task name.
    """
    # Keyword arguments shared by every preprocessor.
    common_args = dict(
        train_data=self.train_df,
        username=self.username,
        project_name=self.project_name,
        valid_data=self.valid_df,
        test_size=self.percent_valid,
        token=self.token,
        seed=42,
    )
    text_tasks = {
        "text_binary_classification": TextBinaryClassificationPreprocessor,
        "text_multi_class_classification": TextMultiClassClassificationPreprocessor,
        "text_single_column_regression": TextSingleColumnRegressionPreprocessor,
    }
    tabular_tasks = {
        "tabular_binary_classification": TabularBinaryClassificationPreprocessor,
        "tabular_multi_class_classification": TabularMultiClassClassificationPreprocessor,
        "tabular_single_column_regression": TabularSingleColumnRegressionPreprocessor,
    }

    if self.task in text_tasks:
        preprocessor = text_tasks[self.task](
            text_column=self.column_mapping["text"],
            label_column=self.column_mapping["label"],
            **common_args,
        )
    elif self.task == "lm_training":
        text_column = self.column_mapping.get("text", None)
        # Generic mode supplies a pre-formatted "text" column; chat mode
        # supplies prompt/response instead (mutually exclusive).
        if text_column is None:
            prompt_column = self.column_mapping["prompt"]
            response_column = self.column_mapping["response"]
        else:
            prompt_column = None
            response_column = None
        preprocessor = LLMPreprocessor(
            text_column=text_column,
            prompt_column=prompt_column,
            response_column=response_column,
            context_column=self.column_mapping.get("context", None),
            prompt_start_column=self.column_mapping.get("prompt_start", None),
            **common_args,
        )
    elif self.task in tabular_tasks:
        id_column = self.column_mapping["id"]
        # A blank id column in the mapping means "no id column".
        if len(id_column.strip()) == 0:
            id_column = None
        preprocessor = tabular_tasks[self.task](
            id_column=id_column,
            label_column=self.column_mapping["label"],
            **common_args,
        )
    else:
        raise ValueError(f"Task {self.task} not supported")

    preprocessor.prepare()
|
autotrain-advanced/src/autotrain/dreambooth_app.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pty
|
| 3 |
+
import random
|
| 4 |
+
import shutil
|
| 5 |
+
import string
|
| 6 |
+
import subprocess
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from huggingface_hub import HfApi, whoami
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ❯ autotrain dreambooth --help
|
| 13 |
+
# usage: autotrain <command> [<args>] dreambooth [-h] --model MODEL [--revision REVISION] [--tokenizer TOKENIZER] --image-path IMAGE_PATH
|
| 14 |
+
# [--class-image-path CLASS_IMAGE_PATH] --prompt PROMPT [--class-prompt CLASS_PROMPT]
|
| 15 |
+
# [--num-class-images NUM_CLASS_IMAGES] [--class-labels-conditioning CLASS_LABELS_CONDITIONING]
|
| 16 |
+
# [--prior-preservation] [--prior-loss-weight PRIOR_LOSS_WEIGHT] --output OUTPUT [--seed SEED]
|
| 17 |
+
# --resolution RESOLUTION [--center-crop] [--train-text-encoder] [--batch-size BATCH_SIZE]
|
| 18 |
+
# [--sample-batch-size SAMPLE_BATCH_SIZE] [--epochs EPOCHS] [--num-steps NUM_STEPS]
|
| 19 |
+
# [--checkpointing-steps CHECKPOINTING_STEPS] [--resume-from-checkpoint RESUME_FROM_CHECKPOINT]
|
| 20 |
+
# [--gradient-accumulation GRADIENT_ACCUMULATION] [--gradient-checkpointing] [--lr LR] [--scale-lr]
|
| 21 |
+
# [--scheduler SCHEDULER] [--warmup-steps WARMUP_STEPS] [--num-cycles NUM_CYCLES] [--lr-power LR_POWER]
|
| 22 |
+
# [--dataloader-num-workers DATALOADER_NUM_WORKERS] [--use-8bit-adam] [--adam-beta1 ADAM_BETA1]
|
| 23 |
+
# [--adam-beta2 ADAM_BETA2] [--adam-weight-decay ADAM_WEIGHT_DECAY] [--adam-epsilon ADAM_EPSILON]
|
| 24 |
+
# [--max-grad-norm MAX_GRAD_NORM] [--allow-tf32]
|
| 25 |
+
# [--prior-generation-precision PRIOR_GENERATION_PRECISION] [--local-rank LOCAL_RANK] [--xformers]
|
| 26 |
+
# [--pre-compute-text-embeddings] [--tokenizer-max-length TOKENIZER_MAX_LENGTH]
|
| 27 |
+
# [--text-encoder-use-attention-mask] [--rank RANK] [--xl] [--fp16] [--bf16] [--hub-token HUB_TOKEN]
|
| 28 |
+
# [--hub-model-id HUB_MODEL_ID] [--push-to-hub] [--validation-prompt VALIDATION_PROMPT]
|
| 29 |
+
# [--num-validation-images NUM_VALIDATION_IMAGES] [--validation-epochs VALIDATION_EPOCHS]
|
| 30 |
+
# [--checkpoints-total-limit CHECKPOINTS_TOTAL_LIMIT] [--validation-images VALIDATION_IMAGES]
|
| 31 |
+
# [--logging]
|
| 32 |
+
|
| 33 |
+
# Repo id of the Space this app runs in (set in the Space's environment).
# Used to detect the canonical demo space and to pause this space after training.
REPO_ID = os.environ.get("REPO_ID")
# Image file extensions accepted by the uploader widget.
ALLOWED_FILE_TYPES = ["png", "jpg", "jpeg"]
# Base diffusion models offered in the UI dropdown (first entry is the default).
MODELS = [
    "stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/stable-diffusion-2-1-base",
]
# Markdown shown at the top of the app.
WELCOME_TEXT = """
Welcome to the AutoTrain DreamBooth! This app allows you to train a DreamBooth model using AutoTrain.
The app runs on HuggingFace Spaces. Your data is not stored anywhere.
The trained model (LoRA) will be pushed to your HuggingFace Hub account.

You need to use your HuggingFace Hub write [token](https://huggingface.co/settings/tokens) to push the model to your account.

NOTE: This space requires GPU to train. Please make sure you have GPU enabled in space settings.
Please make sure to shutdown / pause the space to avoid any additional charges.
"""

# Markdown shown inside the collapsible "Steps" accordion.
STEPS = """
1. [Duplicate](https://huggingface.co/spaces/autotrain-projects/dreambooth?duplicate=true) this space
2. Upgrade the space to GPU
3. Enter your HuggingFace Hub write token
4. Upload images and adjust prompt (remember the prompt!)
5. Click on Train and wait for the training to finish
6. Go to your HuggingFace Hub account to find the trained model

NOTE: For any issues or feature requests, please open an issue [here](https://github.com/huggingface/autotrain-advanced/issues)
"""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _update_project_name():
    """Generate a random project name and set the Train button's state.

    The Train button is disabled while a training-tracker file exists in
    /tmp (i.e. a training run is already in progress in this space).
    """
    chunks = ["".join(random.choices(string.ascii_lowercase + string.digits, k=4)) for _ in range(3)]
    random_project_name = "-".join(chunks)
    # check if training tracker exists
    training_in_progress = os.path.exists(os.path.join("/tmp", "training"))
    return [
        gr.Text.update(value=random_project_name, visible=True, interactive=True),
        gr.Button.update(interactive=not training_in_progress),
    ]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def run_command(cmd):
    """Run *cmd* attached to a pseudo-terminal, streaming its output to stdout.

    A pty is used (instead of a pipe) so the child behaves as if it were
    writing to a terminal (line-buffered / progress-bar friendly output).

    Args:
        cmd: sequence of command arguments; each is converted with ``str``.
    """
    cmd = [str(c) for c in cmd]
    print(f"Running command: {' '.join(cmd)}")
    master, slave = pty.openpty()
    p = subprocess.Popen(cmd, stdout=slave, stderr=slave)
    # The child holds its own copy of the slave end; close ours so EOF is seen.
    os.close(slave)

    try:
        while p.poll() is None:
            try:
                output = os.read(master, 1024).decode()
            except OSError:
                # The pty was closed (child exited); stop streaming.
                break
            print(output, end="")
    finally:
        # Bug fix: the master fd was leaked and the child was never reaped
        # on the early-break path in the original.
        os.close(master)
        p.wait()
|
| 96 |
+
|
| 97 |
+
def _run_training(
    hub_token,
    project_name,
    model,
    images,
    prompt,
    learning_rate,
    num_steps,
    batch_size,
    gradient_accumulation_steps,
    prior_preservation,
    scale_lr,
    use_8bit_adam,
    train_text_encoder,
    gradient_checkpointing,
    center_crop,
    prior_loss_weight,
    num_cycles,
    lr_power,
    adam_beta1,
    adam_beta2,
    adam_weight_decay,
    adam_epsilon,
    max_grad_norm,
    warmup_steps,
    scheduler,
    resolution,
    fp16,
):
    """Launch an `autotrain dreambooth` run from the UI parameters.

    Copies the uploaded images to /tmp/data, builds the CLI command, runs it,
    and pauses the space on success. A /tmp/training marker file serves as a
    single-run lock. Returns a ``gr.Markdown`` update describing the outcome.
    """
    if REPO_ID == "autotrain-projects/dreambooth":
        return gr.Markdown.update(
            value="❌ Please [duplicate](https://huggingface.co/spaces/autotrain-projects/dreambooth?duplicate=true) this space before training."
        )

    api = HfApi(token=hub_token)

    # Single-run lock: refuse to start when a tracker file already exists.
    if os.path.exists(os.path.join("/tmp", "training")):
        return gr.Markdown.update(value="❌ Another training job is already running in this space.")

    with open(os.path.join("/tmp", "training"), "w") as f:
        f.write("training")

    hub_model_id = whoami(token=hub_token)["name"] + "/" + str(project_name).strip()

    image_path = "/tmp/data"
    os.makedirs(image_path, exist_ok=True)
    output_dir = "/tmp/model"
    os.makedirs(output_dir, exist_ok=True)

    for image in images:
        shutil.copy(image.name, image_path)
    cmd = [
        "autotrain",
        "dreambooth",
        "--model",
        model,
        "--output",
        output_dir,
        "--image-path",
        image_path,
        "--prompt",
        prompt,
        "--resolution",
        # Bug fix: resolution was hard-coded to "1024", silently ignoring the
        # UI's Resolution input. run_command() str()-converts every element.
        resolution,
        "--batch-size",
        batch_size,
        "--num-steps",
        num_steps,
        "--gradient-accumulation",
        gradient_accumulation_steps,
        "--lr",
        learning_rate,
        "--scheduler",
        scheduler,
        "--warmup-steps",
        warmup_steps,
        "--num-cycles",
        num_cycles,
        "--lr-power",
        lr_power,
        "--adam-beta1",
        adam_beta1,
        "--adam-beta2",
        adam_beta2,
        "--adam-weight-decay",
        adam_weight_decay,
        "--adam-epsilon",
        adam_epsilon,
        "--max-grad-norm",
        max_grad_norm,
        "--prior-loss-weight",
        prior_loss_weight,
        "--push-to-hub",
        "--hub-token",
        hub_token,
        "--hub-model-id",
        hub_model_id,
    ]

    # Boolean UI toggles map to presence/absence of CLI flags.
    if prior_preservation:
        cmd.append("--prior-preservation")
    if scale_lr:
        cmd.append("--scale-lr")
    if use_8bit_adam:
        cmd.append("--use-8bit-adam")
    if train_text_encoder:
        cmd.append("--train-text-encoder")
    if gradient_checkpointing:
        cmd.append("--gradient-checkpointing")
    if center_crop:
        cmd.append("--center-crop")
    if fp16:
        cmd.append("--fp16")

    try:
        run_command(cmd)
    except Exception as e:
        print(e)
        print("Error running command")
        return gr.Markdown.update(value="❌ Error running command. Please try again.")
    finally:
        # Release the single-run lock on both the success and failure paths.
        os.remove(os.path.join("/tmp", "training"))
    # switch off space (avoid further GPU charges) — only after success
    if REPO_ID is not None:
        api.pause_space(repo_id=REPO_ID)
    return gr.Markdown.update(value=f"✅ Training finished! Model pushed to {hub_model_id}")
| 226 |
+
|
| 227 |
+
def main():
    """Build and return the AutoTrain DreamBooth Gradio UI.

    Layout: welcome text + steps accordion, the Hub token field, a row with
    three columns (project/model/images, prompt + core hyperparameters,
    boolean toggles), an "Advanced Parameters" accordion, and the Train
    button wired to :func:`_run_training`. On page load,
    :func:`_update_project_name` seeds a random project name and
    enables/disables the Train button.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    with gr.Blocks(theme="freddyaboulton/dracula_revamped") as demo:
        gr.Markdown("## 🤗 AutoTrain DreamBooth")
        gr.Markdown(WELCOME_TEXT)
        with gr.Accordion("Steps", open=False):
            gr.Markdown(STEPS)
        # Masked input: the write token is needed to push the trained LoRA.
        hub_token = gr.Textbox(
            label="Hub Token",
            value="",
            lines=1,
            max_lines=1,
            interactive=True,
            type="password",
        )

        with gr.Row():
            with gr.Column():
                project_name = gr.Textbox(
                    label="Project name",
                    value="",
                    lines=1,
                    max_lines=1,
                    interactive=True,
                )
                model = gr.Dropdown(
                    label="Model",
                    choices=MODELS,
                    value=MODELS[0],
                    visible=True,
                    interactive=True,
                    elem_id="model",
                    # NOTE(review): gradio's Dropdown keyword is
                    # `allow_custom_value` (singular) in current releases —
                    # confirm this spelling against the pinned gradio version.
                    allow_custom_values=True,
                )
                images = gr.File(
                    label="Images",
                    file_types=ALLOWED_FILE_TYPES,
                    file_count="multiple",
                    visible=True,
                    interactive=True,
                )

            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="photo of sks dog",
                    lines=1,
                )
                with gr.Row():
                    learning_rate = gr.Number(
                        label="Learning Rate",
                        value=1e-4,
                        visible=True,
                        interactive=True,
                        elem_id="learning_rate",
                    )
                    num_steps = gr.Number(
                        label="Number of Steps",
                        value=500,
                        visible=True,
                        interactive=True,
                        elem_id="num_steps",
                        precision=0,
                    )
                    batch_size = gr.Number(
                        label="Batch Size",
                        value=1,
                        visible=True,
                        interactive=True,
                        elem_id="batch_size",
                        precision=0,
                    )
                with gr.Row():
                    gradient_accumulation_steps = gr.Number(
                        label="Gradient Accumulation Steps",
                        value=4,
                        visible=True,
                        interactive=True,
                        elem_id="gradient_accumulation_steps",
                        precision=0,
                    )
                    resolution = gr.Number(
                        label="Resolution",
                        value=1024,
                        visible=True,
                        interactive=True,
                        elem_id="resolution",
                        precision=0,
                    )
                    scheduler = gr.Dropdown(
                        label="Scheduler",
                        choices=["cosine", "linear", "constant"],
                        value="constant",
                        visible=True,
                        interactive=True,
                        elem_id="scheduler",
                    )
            with gr.Column():
                with gr.Group():
                    fp16 = gr.Checkbox(
                        label="FP16",
                        value=True,
                        visible=True,
                        interactive=True,
                        elem_id="fp16",
                    )
                    prior_preservation = gr.Checkbox(
                        label="Prior Preservation",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="prior_preservation",
                    )
                    scale_lr = gr.Checkbox(
                        label="Scale LR",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="scale_lr",
                    )
                    use_8bit_adam = gr.Checkbox(
                        label="Use 8bit Adam",
                        value=True,
                        visible=True,
                        interactive=True,
                        elem_id="use_8bit_adam",
                    )
                    train_text_encoder = gr.Checkbox(
                        label="Train Text Encoder",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="train_text_encoder",
                    )
                    gradient_checkpointing = gr.Checkbox(
                        label="Gradient Checkpointing",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="gradient_checkpointing",
                    )
                    center_crop = gr.Checkbox(
                        label="Center Crop",
                        value=False,
                        visible=True,
                        interactive=True,
                        elem_id="center_crop",
                    )
        # Rarely-changed optimizer/scheduler details, collapsed by default.
        with gr.Accordion("Advanced Parameters", open=False):
            with gr.Row():
                prior_loss_weight = gr.Number(
                    label="Prior Loss Weight",
                    value=1.0,
                    visible=True,
                    interactive=True,
                    elem_id="prior_loss_weight",
                )
                num_cycles = gr.Number(
                    label="Num Cycles",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="num_cycles",
                    precision=0,
                )
                lr_power = gr.Number(
                    label="LR Power",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="lr_power",
                )

                adam_beta1 = gr.Number(
                    label="Adam Beta1",
                    value=0.9,
                    visible=True,
                    interactive=True,
                    elem_id="adam_beta1",
                )
                adam_beta2 = gr.Number(
                    label="Adam Beta2",
                    value=0.999,
                    visible=True,
                    interactive=True,
                    elem_id="adam_beta2",
                )
                adam_weight_decay = gr.Number(
                    label="Adam Weight Decay",
                    value=1e-2,
                    visible=True,
                    interactive=True,
                    elem_id="adam_weight_decay",
                )
                adam_epsilon = gr.Number(
                    label="Adam Epsilon",
                    value=1e-8,
                    visible=True,
                    interactive=True,
                    elem_id="adam_epsilon",
                )
                max_grad_norm = gr.Number(
                    label="Max Grad Norm",
                    value=1,
                    visible=True,
                    interactive=True,
                    elem_id="max_grad_norm",
                )
                warmup_steps = gr.Number(
                    label="Warmup Steps",
                    value=0,
                    visible=True,
                    interactive=True,
                    elem_id="warmup_steps",
                    precision=0,
                )

        train_button = gr.Button(value="Train", elem_id="train")
        output_md = gr.Markdown("## Output")
        # Order must match _run_training's positional parameters exactly.
        inputs = [
            hub_token,
            project_name,
            model,
            images,
            prompt,
            learning_rate,
            num_steps,
            batch_size,
            gradient_accumulation_steps,
            prior_preservation,
            scale_lr,
            use_8bit_adam,
            train_text_encoder,
            gradient_checkpointing,
            center_crop,
            prior_loss_weight,
            num_cycles,
            lr_power,
            adam_beta1,
            adam_beta2,
            adam_weight_decay,
            adam_epsilon,
            max_grad_norm,
            warmup_steps,
            scheduler,
            resolution,
            fp16,
        ]

        train_button.click(_run_training, inputs=inputs, outputs=output_md)
        demo.load(
            _update_project_name,
            outputs=[project_name, train_button],
        )
    return demo
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
if __name__ == "__main__":
    # Build the Gradio app and serve it.
    main().launch()
|
autotrain-advanced/src/autotrain/help.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Help text shown next to the username/organization selector.
APP_AUTOTRAIN_USERNAME = """Please choose the user or organization who is creating the AutoTrain Project.
In case of non-free tier, this user or organization will be billed.
"""

# Help text shown next to the project-name field.
APP_PROJECT_NAME = """A unique name for the AutoTrain Project.
This name will be used to identify the project in the AutoTrain dashboard."""


# Expected upload format for the image-classification task.
APP_IMAGE_CLASSIFICATION_DATA_HELP = """The data for the Image Classification task should be in the following format:
- The data should be in a zip file.
- The zip file should contain multiple folders (the classes), each folder should contain images of a single class.
- The name of the folder should be the name of the class.
- The images must be jpeg, jpg or png.
- There should be at least 5 images per class.
- There should not be any other files in the zip file.
- There should not be any other folders inside the zip folder.
"""

# Explanation of the two LM training modes (generic vs chat).
APP_LM_TRAINING_TYPE = """There are two types of Language Model Training:
- generic
- chat

In the generic mode, you provide a CSV with a text column which has already been formatted by you for training a language model.
In the chat mode, you provide a CSV with two or three text columns: prompt, context (optional) and response.
Context column can be empty for samples if not needed. You can also have a "prompt start" column. If provided, "prompt start" will be prepended before the prompt column.

Please see [this](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset which has both formats in the same dataset.
"""
|
autotrain-advanced/src/autotrain/infer/__init__.py
ADDED
|
File without changes
|
autotrain-advanced/src/autotrain/infer/text_generation.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class TextGenerationInference:
    """Thin inference wrapper around a causal language model.

    Loads the model/tokenizer once in ``__post_init__`` and exposes
    :meth:`chat` for single-prompt generation.
    """

    model_path: str = "gpt2"  # Hub id or local path of the model
    use_int4: Optional[bool] = False  # load with 4-bit quantization
    use_int8: Optional[bool] = False  # load with 8-bit quantization
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.95
    repetition_penalty: Optional[float] = 1.0
    num_return_sequences: Optional[int] = 1
    num_beams: Optional[int] = 1
    max_new_tokens: Optional[int] = 1024  # generated tokens, excluding the prompt
    do_sample: Optional[bool] = True

    def __post_init__(self):
        # NOTE(review): torch_dtype=float16 assumes accelerator-backed
        # inference — confirm behavior for CPU-only environments.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_4bit=self.use_int4,
            load_in_8bit=self.use_int8,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model.eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generation_config = GenerationConfig(
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            num_return_sequences=self.num_return_sequences,
            num_beams=self.num_beams,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=self.do_sample,
            # Fix: the original also passed max_length=self.max_new_tokens.
            # When both are set, transformers warns and max_new_tokens takes
            # precedence, so passing only max_new_tokens preserves behavior
            # while removing the conflicting setting.
            max_new_tokens=self.max_new_tokens,
        )

    def chat(self, prompt):
        """Generate a continuation for *prompt* and return the decoded text
        (includes the prompt, since the full output sequence is decoded)."""
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
autotrain-advanced/src/autotrain/languages.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SUPPORTED_LANGUAGES = [
|
| 2 |
+
"en",
|
| 3 |
+
"ar",
|
| 4 |
+
"bn",
|
| 5 |
+
"de",
|
| 6 |
+
"es",
|
| 7 |
+
"fi",
|
| 8 |
+
"fr",
|
| 9 |
+
"hi",
|
| 10 |
+
"it",
|
| 11 |
+
"ja",
|
| 12 |
+
"ko",
|
| 13 |
+
"nl",
|
| 14 |
+
"pt",
|
| 15 |
+
"sv",
|
| 16 |
+
"tr",
|
| 17 |
+
"zh",
|
| 18 |
+
"unk",
|
| 19 |
+
]
|
autotrain-advanced/src/autotrain/params.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
from autotrain.languages import SUPPORTED_LANGUAGES
|
| 8 |
+
from autotrain.tasks import TASKS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LoraR:
    """UI/config spec for the LoRA rank (r) hyperparameter.

    Each spec class in this module bundles the value type, bounds, default,
    and the widget used to render it in the Streamlit and Gradio UIs.
    """

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 100
    DEFAULT = 16
    STEP = 1
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "LoRA R"
    # Widget is instantiated once at class-definition (import) time.
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LoraAlpha:
    """UI/config spec for the LoRA alpha (scaling) hyperparameter."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 256
    DEFAULT = 32
    STEP = 1
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "LoRA Alpha"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LoraDropout:
    """UI/config spec for the LoRA dropout probability (0.0 - 1.0)."""

    TYPE = "float"
    MIN_VALUE = 0.0
    MAX_VALUE = 1.0
    DEFAULT = 0.05
    STEP = 0.01
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "LoRA Dropout"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LearningRate:
    """UI/config spec for the optimizer learning rate."""

    TYPE = "float"
    MIN_VALUE = 1e-7
    MAX_VALUE = 1e-1
    DEFAULT = 1e-3
    STEP = 1e-6
    # Scientific-notation display format used by the Streamlit widget.
    FORMAT = "%.2E"
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Learning Rate"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class LMLearningRate(LearningRate):
    """Learning-rate spec for LM fine-tuning; only the default differs."""

    DEFAULT = 5e-5
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Optimizer:
    """UI/config spec for the optimizer choice (HF Trainer optimizer names)."""

    TYPE = "str"
    DEFAULT = "adamw_torch"
    CHOICES = ["adamw_torch", "adamw_hf", "sgd", "adafactor", "adagrad"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Optimizer"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class LMTrainingType:
    """UI/config spec for the language-model training type selector."""

    TYPE = "str"
    DEFAULT = "generic"
    CHOICES = ["generic", "chat"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "LM Training Type"
    # Fix: the attribute was misspelled "GRAIDO_INPUT", so any generic code
    # reading GRADIO_INPUT (the name every sibling spec class defines) would
    # fail with AttributeError for this class.
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
    # Backward-compatible alias for the historical typo.
    GRAIDO_INPUT = GRADIO_INPUT
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Scheduler:
    """UI/config spec for the learning-rate scheduler choice."""

    TYPE = "str"
    DEFAULT = "linear"
    CHOICES = ["linear", "cosine"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Scheduler"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TrainBatchSize:
    """UI/config spec for the per-device training batch size."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 128
    DEFAULT = 2
    STEP = 2
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Train Batch Size"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class LMTrainBatchSize(TrainBatchSize):
    """Batch-size spec for LM fine-tuning; only the default differs."""

    DEFAULT = 4
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Epochs:
    """UI/config spec for the number of training epochs."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 1000
    DEFAULT = 10
    STREAMLIT_INPUT = "number_input"
    # NOTE(review): rendered as a free-form Number (not a Slider), so the
    # MIN_VALUE/MAX_VALUE bounds are not enforced by the Gradio widget.
    GRADIO_INPUT = gr.Number(value=DEFAULT)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class LMEpochs(Epochs):
    """Epoch spec for LM fine-tuning; only the default differs."""

    DEFAULT = 1
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class PercentageWarmup:
    """UI/config spec for the warmup proportion (fraction of total steps)."""

    TYPE = "float"
    MIN_VALUE = 0.0
    MAX_VALUE = 1.0
    DEFAULT = 0.1
    STEP = 0.01
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Percentage Warmup"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=STEP)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class GradientAccumulationSteps:
    """UI/config spec for the number of gradient accumulation steps."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 100
    DEFAULT = 1
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Gradient Accumulation Steps"
    GRADIO_INPUT = gr.Number(value=DEFAULT)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class WeightDecay:
    """UI/config spec for the optimizer weight-decay coefficient."""

    TYPE = "float"
    MIN_VALUE = 0.0
    MAX_VALUE = 1.0
    DEFAULT = 0.0
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Weight Decay"
    GRADIO_INPUT = gr.Number(value=DEFAULT)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class SourceLanguage:
    """UI/config spec for the dataset's source language code."""

    TYPE = "str"
    DEFAULT = "en"
    CHOICES = SUPPORTED_LANGUAGES
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Source Language"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class TargetLanguage:
    """UI/config spec for the dataset's target language code."""

    TYPE = "str"
    DEFAULT = "en"
    CHOICES = SUPPORTED_LANGUAGES
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Target Language"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class NumModels:
    """UI/config spec for how many candidate models AutoTrain should train."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 25
    DEFAULT = 1
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Number of Models"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=1)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class DBNumSteps:
    """UI/config spec for the number of Dreambooth training steps."""

    TYPE = "int"
    MIN_VALUE = 100
    MAX_VALUE = 10000
    DEFAULT = 1500
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Number of Steps"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=100)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class DBTextEncoderStepsPercentage:
    """UI/config spec for the % of steps spent training the text encoder (Dreambooth)."""

    TYPE = "int"
    MIN_VALUE = 1
    MAX_VALUE = 100
    DEFAULT = 30
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Text encoder steps percentage"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class DBPriorPreservation:
    """UI/config spec for the Dreambooth prior-preservation toggle."""

    TYPE = "bool"
    DEFAULT = False
    STREAMLIT_INPUT = "checkbox"
    PRETTY_NAME = "Prior preservation"
    # NOTE(review): declared as a bool, but the Gradio widget yields the
    # strings "True"/"False" — consumers presumably parse these; confirm.
    GRADIO_INPUT = gr.Dropdown(["True", "False"], value="False")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class ImageSize:
    """UI/config spec for the training image resolution (pixels, square)."""

    TYPE = "int"
    MIN_VALUE = 64
    MAX_VALUE = 2048
    DEFAULT = 512
    STREAMLIT_INPUT = "number_input"
    PRETTY_NAME = "Image Size"
    GRADIO_INPUT = gr.Slider(minimum=MIN_VALUE, maximum=MAX_VALUE, value=DEFAULT, step=64)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class DreamboothConceptType:
    """UI/config spec for the kind of Dreambooth concept being trained."""

    TYPE = "str"
    DEFAULT = "person"
    CHOICES = ["person", "object"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Concept Type"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class SourceLanguageUnk:
    """Source-language spec used when the language is unknown/irrelevant
    (e.g. training a user-supplied hub model): the only choice is "unk"."""

    TYPE = "str"
    DEFAULT = "unk"
    CHOICES = ["unk"]
    STREAMLIT_INPUT = "selectbox"
    PRETTY_NAME = "Source Language"
    GRADIO_INPUT = gr.Dropdown(CHOICES, value=DEFAULT)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class HubModel:
    """UI/config spec for a free-text Hugging Face Hub model id."""

    TYPE = "str"
    DEFAULT = "bert-base-uncased"
    PRETTY_NAME = "Hub Model"
    GRADIO_INPUT = gr.Textbox(lines=1, max_lines=1, label="Hub Model")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class TextBinaryClassificationParams(BaseModel):
    """Training hyperparameters for binary text classification.

    The ``task`` Literal doubles as the discriminator when these models are
    parsed from a tagged payload.
    """

    task: Literal["text_binary_classification"]
    learning_rate: float = Field(5e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    max_seq_length: int = Field(128, title="Max sequence length")
    train_batch_size: int = Field(32, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class TextMultiClassClassificationParams(BaseModel):
    """Training hyperparameters for multi-class text classification.

    Field-for-field identical to TextBinaryClassificationParams except for
    the ``task`` discriminator.
    """

    task: Literal["text_multi_class_classification"]
    learning_rate: float = Field(5e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    max_seq_length: int = Field(128, title="Max sequence length")
    train_batch_size: int = Field(32, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DreamboothParams(BaseModel):
    """Training hyperparameters for Dreambooth fine-tuning."""

    task: Literal["dreambooth"]
    num_steps: int = Field(1500, title="Number of steps")
    image_size: int = Field(512, title="Image size")
    text_encoder_steps_percentage: int = Field(30, title="Text encoder steps percentage")
    prior_preservation: bool = Field(False, title="Prior preservation")
    learning_rate: float = Field(2e-6, title="Learning rate")
    train_batch_size: int = Field(1, title="Training batch size")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class ImageBinaryClassificationParams(BaseModel):
    """Training hyperparameters for binary image classification."""

    task: Literal["image_binary_classification"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class ImageMultiClassClassificationParams(BaseModel):
    """Training hyperparameters for multi-class image classification.

    Field-for-field identical to ImageBinaryClassificationParams except for
    the ``task`` discriminator.
    """

    task: Literal["image_multi_class_classification"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class LMTrainingParams(BaseModel):
    """Training hyperparameters for causal LM fine-tuning (incl. LoRA)."""

    task: Literal["lm_training"]
    learning_rate: float = Field(3e-5, title="Learning rate")
    num_train_epochs: int = Field(3, title="Number of training epochs")
    train_batch_size: int = Field(8, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation_steps: int = Field(1, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")
    add_eos_token: bool = Field(True, title="Add EOS token")
    # -1 presumably means "derive block size automatically" — confirm
    # against the trainer that consumes it.
    block_size: int = Field(-1, title="Block size")
    lora_r: int = Field(16, title="Lora r")
    lora_alpha: int = Field(32, title="Lora alpha")
    lora_dropout: float = Field(0.05, title="Lora dropout")
    training_type: str = Field("generic", title="Training type")
    train_on_inputs: bool = Field(False, title="Train on inputs")
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@dataclass
class Params:
    """Resolve which hyperparameter specs the UI should expose for a given
    combination of task, parameter mode and model mode.

    Attributes:
        task: one of the keys of TASKS (validated in __post_init__).
        param_choice: "autotrain" (parameters tuned automatically) or "manual".
        model_choice: "autotrain" (model picked automatically) or "hub_model".

    Raises:
        ValueError: from __post_init__ on any invalid attribute, or from
            get() for an unsupported task.
    """

    task: str
    param_choice: str
    model_choice: str

    def __post_init__(self):
        # task should be one of the keys in TASKS
        if self.task not in TASKS:
            raise ValueError(f"task must be one of {TASKS.keys()}")
        self.task_id = TASKS[self.task]

        if self.param_choice not in ("autotrain", "manual"):
            raise ValueError("param_choice must be either autotrain or manual")

        if self.model_choice not in ("autotrain", "hub_model"):
            raise ValueError("model_choice must be either autotrain or hub_model")

    def _dreambooth(self):
        # Manual mode exposes the full set of training knobs; autotrain mode
        # only asks for the model count (plus the hub model when one is used).
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "image_size": ImageSize,
                "learning_rate": LearningRate,
                "train_batch_size": TrainBatchSize,
                "num_steps": DBNumSteps,
                "gradient_accumulation_steps": GradientAccumulationSteps,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "hub_model":
                return {
                    "hub_model": HubModel,
                    "image_size": ImageSize,
                    "num_models": NumModels,
                }
            else:
                return {
                    "num_models": NumModels,
                }
        # Consistency fix: siblings (_lm_training, _text_binary_classification)
        # raise here instead of silently returning None. Unreachable in
        # practice because __post_init__ validates param_choice.
        raise ValueError("param_choice must be either autotrain or manual")

    def _tabular_binary_classification(self):
        # Tabular tasks are always AutoTrain-tuned: only the model count is exposed.
        return {
            "num_models": NumModels,
        }

    def _lm_training(self):
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LMLearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": LMTrainBatchSize,
                "num_train_epochs": LMEpochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
                "lora_r": LoraR,
                "lora_alpha": LoraAlpha,
                "lora_dropout": LoraDropout,
                "training_type": LMTrainingType,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "num_models": NumModels,
                    "training_type": LMTrainingType,
                }
            else:
                return {
                    "hub_model": HubModel,
                    "num_models": NumModels,
                    "training_type": LMTrainingType,
                }
        raise ValueError("param_choice must be either autotrain or manual")

    def _tabular_multi_class_classification(self):
        return self._tabular_binary_classification()

    def _tabular_single_column_regression(self):
        return self._tabular_binary_classification()

    def _tabular_multi_label_classification(self):
        return self._tabular_binary_classification()

    # Backward-compatible alias: this delegate was historically public
    # (missing the leading underscore its siblings have).
    tabular_multi_label_classification = _tabular_multi_label_classification

    def _text_binary_classification(self):
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": TrainBatchSize,
                "num_train_epochs": Epochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "source_language": SourceLanguage,
                    "num_models": NumModels,
                }
            # hub_model: the language is irrelevant, so only "unk" is offered.
            return {
                "hub_model": HubModel,
                "source_language": SourceLanguageUnk,
                "num_models": NumModels,
            }
        raise ValueError("param_choice must be either autotrain or manual")

    def _text_multi_class_classification(self):
        return self._text_binary_classification()

    def _text_entity_extraction(self):
        return self._text_binary_classification()

    def _text_single_column_regression(self):
        return self._text_binary_classification()

    def _text_natural_language_inference(self):
        return self._text_binary_classification()

    def _image_binary_classification(self):
        if self.param_choice == "manual":
            return {
                "hub_model": HubModel,
                "learning_rate": LearningRate,
                "optimizer": Optimizer,
                "scheduler": Scheduler,
                "train_batch_size": TrainBatchSize,
                "num_train_epochs": Epochs,
                "percentage_warmup": PercentageWarmup,
                "gradient_accumulation_steps": GradientAccumulationSteps,
                "weight_decay": WeightDecay,
            }
        if self.param_choice == "autotrain":
            if self.model_choice == "autotrain":
                return {
                    "num_models": NumModels,
                }
            return {
                "hub_model": HubModel,
                "num_models": NumModels,
            }
        raise ValueError("param_choice must be either autotrain or manual")

    def _image_multi_class_classification(self):
        return self._image_binary_classification()

    def get(self):
        """Return the {field_name: spec class} mapping for self.task."""
        if self.task in ("text_binary_classification", "text_multi_class_classification"):
            return self._text_binary_classification()

        if self.task == "text_entity_extraction":
            return self._text_entity_extraction()

        if self.task == "text_single_column_regression":
            return self._text_single_column_regression()

        if self.task == "text_natural_language_inference":
            return self._text_natural_language_inference()

        if self.task == "tabular_binary_classification":
            return self._tabular_binary_classification()

        if self.task == "tabular_multi_class_classification":
            return self._tabular_multi_class_classification()

        if self.task == "tabular_single_column_regression":
            return self._tabular_single_column_regression()

        if self.task == "tabular_multi_label_classification":
            return self._tabular_multi_label_classification()

        if self.task in ("image_binary_classification", "image_multi_class_classification"):
            return self._image_binary_classification()

        if self.task == "dreambooth":
            return self._dreambooth()

        if self.task == "lm_training":
            return self._lm_training()

        raise ValueError(f"task {self.task} not supported")
|
autotrain-advanced/src/autotrain/preprocessor/__init__.py
ADDED
|
File without changes
|
autotrain-advanced/src/autotrain/preprocessor/dreambooth.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import json
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, List
|
| 5 |
+
|
| 6 |
+
from huggingface_hub import HfApi, create_repo
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class DreamboothPreprocessor:
    """Pushes Dreambooth concept images and their prompt mapping to a new
    private Hugging Face dataset repo named
    ``{username}/autotrain-data-{project_name}``.

    Attributes:
        concept_images: uploaded file objects (must expose ``.name``, a path).
        concept_name: prompt/instance name stored for the concept.
        username: Hub user/org owning the dataset repo.
        project_name: AutoTrain project name (part of the repo id).
        token: Hub write token.

    Raises:
        ValueError: if the dataset repo cannot be created (it must not
            already exist — ``exist_ok=False``).
    """

    concept_images: List[Any]
    concept_name: str
    username: str
    project_name: str
    token: str

    def __post_init__(self):
        self.repo_name = f"{self.username}/autotrain-data-{self.project_name}"
        try:
            create_repo(
                repo_id=self.repo_name,
                repo_type="dataset",
                token=self.token,
                private=True,
                exist_ok=False,
            )
        except Exception as exc:
            # Fix: log the underlying error and chain it, instead of
            # discarding the cause entirely.
            logger.error(f"Error creating repo: {exc}")
            raise ValueError("Error creating repo") from exc

    def _upload_concept_images(self, file, api):
        # All images of the (single) concept live under "concept1/" in the repo.
        logger.info(f"Uploading {file} to concept1")
        api.upload_file(
            path_or_fileobj=file.name,
            path_in_repo=f"concept1/{file.name.split('/')[-1]}",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def _upload_concept_prompts(self, api):
        # prompts.json maps concept folder name -> prompt text.
        _prompts = {}
        _prompts["concept1"] = self.concept_name

        prompts = json.dumps(_prompts)
        prompts = prompts.encode("utf-8")
        prompts = io.BytesIO(prompts)
        api.upload_file(
            path_or_fileobj=prompts,
            path_in_repo="prompts.json",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def prepare(self):
        """Upload every concept image, then the prompts file."""
        api = HfApi()
        for _file in self.concept_images:
            self._upload_concept_images(_file, api)

        self._upload_concept_prompts(api)
|
autotrain-advanced/src/autotrain/preprocessor/tabular.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
RESERVED_COLUMNS = ["autotrain_id", "autotrain_label"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class TabularBinaryClassificationPreprocessor:
    """Validates, splits and uploads a tabular binary-classification dataset.

    Copies label (and optional id) columns into the reserved
    ``autotrain_label`` / ``autotrain_id`` columns, then pushes train and
    validation splits to ``{username}/autotrain-data-{project_name}``.

    Attributes:
        train_data: training dataframe.
        label_column: name of the target column (must exist in the data).
        username / project_name: used to build the Hub dataset repo id.
        id_column: optional row-id column.
        valid_data: optional pre-made validation dataframe; when absent a
            stratified split of ``train_data`` is made.
        test_size: validation fraction for the automatic split.
        seed: random state for the automatic split.
    """

    train_data: pd.DataFrame
    label_column: str
    username: str
    project_name: str
    id_column: Optional[str] = None
    valid_data: Optional[pd.DataFrame] = None
    test_size: Optional[float] = 0.2
    seed: Optional[int] = 42

    def __post_init__(self):
        # check if id_column and label_column are in train_data
        if self.id_column is not None:
            if self.id_column not in self.train_data.columns:
                raise ValueError(f"{self.id_column} not in train data")

        if self.label_column not in self.train_data.columns:
            raise ValueError(f"{self.label_column} not in train data")

        # check if id_column and label_column are in valid_data
        if self.valid_data is not None:
            if self.id_column is not None:
                if self.id_column not in self.valid_data.columns:
                    raise ValueError(f"{self.id_column} not in valid data")
            if self.label_column not in self.valid_data.columns:
                raise ValueError(f"{self.label_column} not in valid data")

        # make sure no reserved columns are in train_data or valid_data
        for column in RESERVED_COLUMNS:
            if column in self.train_data.columns:
                raise ValueError(f"{column} is a reserved column name")
            if self.valid_data is not None:
                if column in self.valid_data.columns:
                    raise ValueError(f"{column} is a reserved column name")

    def split(self):
        """Return (train_df, valid_df), stratifying on the label when an
        automatic split is needed."""
        if self.valid_data is not None:
            return self.train_data, self.valid_data
        else:
            train_df, valid_df = train_test_split(
                self.train_data,
                test_size=self.test_size,
                random_state=self.seed,
                stratify=self.train_data[self.label_column],
            )
            train_df = train_df.reset_index(drop=True)
            valid_df = valid_df.reset_index(drop=True)
            return train_df, valid_df

    def prepare_columns(self, train_df, valid_df):
        """Copy id/label into the reserved autotrain_* columns and drop the
        originals.

        Bug fix: id_column is Optional (and __post_init__ accepts None), but
        the original indexed ``train_df[self.id_column]`` unconditionally,
        raising KeyError whenever id_column was None. Now the row index is
        used as the id in that case.
        """
        if self.id_column is not None:
            train_df.loc[:, "autotrain_id"] = train_df[self.id_column]
            valid_df.loc[:, "autotrain_id"] = valid_df[self.id_column]
            drop_cols = [self.id_column, self.label_column]
        else:
            train_df.loc[:, "autotrain_id"] = train_df.index.values
            valid_df.loc[:, "autotrain_id"] = valid_df.index.values
            drop_cols = [self.label_column]
        train_df.loc[:, "autotrain_label"] = train_df[self.label_column]
        valid_df.loc[:, "autotrain_label"] = valid_df[self.label_column]

        # drop id_column (when present) and label_column
        train_df = train_df.drop(columns=drop_cols)
        valid_df = valid_df.drop(columns=drop_cols)
        return train_df, valid_df

    def prepare(self):
        """Split, rename columns, and push both splits to the Hub."""
        train_df, valid_df = self.split()
        train_df, valid_df = self.prepare_columns(train_df, valid_df)
        train_df = Dataset.from_pandas(train_df)
        valid_df = Dataset.from_pandas(valid_df)
        train_df.push_to_hub(f"{self.username}/autotrain-data-{self.project_name}", split="train", private=True)
        valid_df.push_to_hub(f"{self.username}/autotrain-data-{self.project_name}", split="validation", private=True)
        return train_df, valid_df
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class TabularMultiClassClassificationPreprocessor(TabularBinaryClassificationPreprocessor):
    """Multi-class variant; the binary preprocessor's behavior (including the
    stratified split) applies unchanged."""

    pass
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TabularSingleColumnRegressionPreprocessor(TabularBinaryClassificationPreprocessor):
    """Regression variant of the tabular preprocessor.

    Overrides split() only to remove stratification, since a continuous
    target cannot be stratified.
    """

    def split(self):
        # Same as the parent split, minus the stratify argument.
        if self.valid_data is not None:
            return self.train_data, self.valid_data
        else:
            train_df, valid_df = train_test_split(
                self.train_data,
                test_size=self.test_size,
                random_state=self.seed,
            )
            train_df = train_df.reset_index(drop=True)
            valid_df = valid_df.reset_index(drop=True)
            return train_df, valid_df
|