| | import subprocess |
| | from argparse import ArgumentParser |
| |
|
| | from loguru import logger |
| |
|
| | from . import BaseAutoTrainCommand |
| |
|
| |
|
def run_app_command_factory(args):
    """Create the ``setup`` command object from parsed CLI arguments.

    Args:
        args: Parsed argument namespace; only its ``update_torch``
            attribute is read.

    Returns:
        A ``RunSetupCommand`` built with the ``update_torch`` value.
    """
    wants_torch_update = args.update_torch
    return RunSetupCommand(wants_torch_update)
| |
|
| |
|
class RunSetupCommand(BaseAutoTrainCommand):
    """CLI ``setup`` subcommand that reinstalls ML libraries from source.

    ``run`` uninstalls and then reinstalls transformers, peft, diffusers and
    trl from their GitHub repositories, and optionally installs PyTorch from
    the wheel index hard-coded in ``_TORCH_INSTALL_CMD``.
    """

    # (label, shell command) pairs, executed in this order by ``run``.
    # Each command is a constant string run through the shell because it
    # chains two pip invocations with ``&&``.
    _SOURCE_INSTALLS = (
        (
            "transformers",
            "pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.git",
        ),
        (
            "peft",
            "pip uninstall -y peft && pip install git+https://github.com/huggingface/peft.git",
        ),
        (
            "diffusers",
            "pip uninstall -y diffusers && pip install git+https://github.com/huggingface/diffusers.git",
        ),
        (
            "trl",
            "pip uninstall -y trl && pip install git+https://github.com/lvwerra/trl.git",
        ),
    )

    # Only run when the command was created with ``update_torch=True``.
    _TORCH_INSTALL_CMD = "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``setup`` subcommand and its ``--update-torch`` flag.

        Args:
            parser: Object exposing ``add_parser`` (presumably the
                sub-parsers action of the main CLI parser, despite the
                annotation -- verify against the caller).
        """
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool):
        """Store whether ``run`` should also install PyTorch.

        Args:
            update_torch: When true, ``run`` additionally executes the
                PyTorch install command after the source installs.
        """
        self.update_torch = update_torch

    @staticmethod
    def _install(label: str, cmd: str, ref: str = "") -> bool:
        """Run one install command through the shell and log the outcome.

        Args:
            label: Human-readable package name used in log messages.
            cmd: Shell command line to execute.
            ref: Optional suffix appended to the label in the start message
                (for example ``"@main"``).

        Returns:
            True if the command exited with status 0, False otherwise.
        """
        # Log before spawning so the message really precedes the work.
        logger.info(f"Installing latest {label}{ref}")
        result = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
        if result.returncode != 0:
            # Report the failure instead of claiming success; stay
            # best-effort so the remaining installs are still attempted.
            stderr_text = result.stderr.decode("utf-8", errors="replace").strip()
            logger.error(
                "Failed to install latest {} (exit code {}): {}",
                label,
                result.returncode,
                stderr_text,
            )
            return False
        logger.info(f"Successfully installed latest {label}")
        return True

    def run(self):
        """Execute every install step in order, logging success or failure.

        A failing step is logged as an error and does not stop the
        remaining steps; no exception is raised for a non-zero exit status.
        """
        for label, cmd in self._SOURCE_INSTALLS:
            self._install(label, cmd, ref="@main")

        if self.update_torch:
            self._install("PyTorch", self._TORCH_INSTALL_CMD)
| |
|