Hemanth-thunder's picture
End of training
c0551d3
import subprocess
from argparse import ArgumentParser
from loguru import logger
from . import BaseAutoTrainCommand
def run_app_command_factory(args):
    """Build the ``setup`` command object from parsed CLI arguments.

    Args:
        args: Parsed argument namespace; only its ``update_torch`` attribute
            is read.

    Returns:
        A ``RunSetupCommand`` constructed with that flag.
    """
    update_torch = args.update_torch
    return RunSetupCommand(update_torch)
class RunSetupCommand(BaseAutoTrainCommand):
    """``setup`` subcommand: reinstall a set of libraries from their Git repositories.

    ``run`` uninstalls and reinstalls transformers, peft, diffusers and trl
    from GitHub via pip, and optionally installs PyTorch from the cu118
    wheel index when ``update_torch`` is set.
    """

    # (package name, shell command) pairs, executed in this order by ``run``.
    _GIT_INSTALLS = (
        (
            "transformers",
            "pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.git",
        ),
        (
            "peft",
            "pip uninstall -y peft && pip install git+https://github.com/huggingface/peft.git",
        ),
        (
            "diffusers",
            "pip uninstall -y diffusers && pip install git+https://github.com/huggingface/diffusers.git",
        ),
        (
            "trl",
            "pip uninstall -y trl && pip install git+https://github.com/lvwerra/trl.git",
        ),
    )

    # Command used when ``--update-torch`` is passed.
    _TORCH_INSTALL = "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``setup`` subcommand and its ``--update-torch`` flag.

        Args:
            parser: Object exposing ``add_parser`` on which the subcommand is
                created (annotated as ``ArgumentParser`` in this file).
        """
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool):
        """Store whether PyTorch should also be installed by ``run``."""
        self.update_torch = update_torch

    @staticmethod
    def _install(cmd: str, name: str, ref: str = "") -> bool:
        """Run one pip shell command and report whether it succeeded.

        Args:
            cmd: Shell command line to execute.
            name: Package name used in log messages.
            ref: Optional suffix (e.g. ``"@main"``) appended to ``name`` in
                the "Installing" message only.

        Returns:
            True if the command exited with status 0, False otherwise.
        """
        logger.info("Installing latest {}{}", name, ref)
        # The commands are fixed strings chained with `&&`, so they need a shell.
        result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            # Surface pip's own error output instead of silently discarding it.
            details = result.stderr.decode(errors="replace").strip()
            logger.error(
                "Failed to install {} (exit code {}): {}",
                name,
                result.returncode,
                details,
            )
            return False
        logger.info("Successfully installed latest {}", name)
        return True

    def run(self):
        """Run every install step in order, logging success or failure for each.

        A failing step is logged as an error and does not prevent the
        remaining steps from running.
        """
        for name, cmd in self._GIT_INSTALLS:
            self._install(cmd, name, "@main")

        if self.update_torch:
            self._install(self._TORCH_INSTALL, "PyTorch")