| from argparse import ArgumentParser |
|
|
| from . import BaseAutoTrainCommand |
|
|
|
|
def run_app_command_factory(args):
    """Build the ``app`` command object from parsed command-line arguments.

    Args:
        args: Parsed arguments object exposing ``port``, ``host`` and
            ``task`` attributes (as registered by
            ``RunAutoTrainAppCommand.register_subcommand``).

    Returns:
        A ``RunAutoTrainAppCommand`` initialised with those three values.
    """
    port, host, task = args.port, args.host, args.task
    return RunAutoTrainAppCommand(port, host, task)
|
|
|
|
class RunAutoTrainAppCommand(BaseAutoTrainCommand):
    """CLI command behind the ``app`` subcommand: builds and launches the app UI."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``app`` subcommand and its options.

        Args:
            parser: Object on which ``add_parser`` is called to create the
                ``app`` sub-parser.
                NOTE(review): ``add_parser`` is a method of the object
                returned by ``ArgumentParser.add_subparsers()``, not of
                ``ArgumentParser`` itself -- the annotation is kept as-is for
                compatibility; confirm against the caller.

        The sub-parser accepts ``--port`` (int, default 7860), ``--host``
        (str, default ``127.0.0.1``) and an optional ``--task``; its ``func``
        default is set to ``run_app_command_factory``.
        """
        run_app_parser = parser.add_parser(
            "app",
            description="✨ Run AutoTrain app",
        )
        run_app_parser.add_argument(
            "--port",
            type=int,
            default=7860,
            help="Port to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--host",
            type=str,
            default="127.0.0.1",
            help="Host to run the app on",
            required=False,
        )
        run_app_parser.add_argument(
            "--task",
            type=str,
            required=False,
            help="Task to run",
        )
        run_app_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, port, host, task):
        """Store the launch settings.

        Args:
            port: Port to serve the app on.
            host: Host/interface to bind the app to.
            task: Task name; ``"dreambooth"`` selects the dreambooth app,
                any other value (including ``None``) selects the default app.
        """
        self.port = port
        self.host = host
        self.task = task

    def run(self):
        """Build the app for the selected task and launch it on host/port."""
        # Import lazily so only the selected app module is loaded.
        if self.task == "dreambooth":
            from ..dreambooth_app import main
        else:
            from ..app import main

        demo = main()
        # Previously launch() was called without arguments, so the --host and
        # --port options were accepted but silently ignored.
        # NOTE(review): assumes `demo` is a Gradio Blocks-like object whose
        # launch() accepts server_name/server_port -- confirm against main().
        # With an explicit server_port the launch fails if that port is busy
        # instead of falling back to another port.
        demo.queue(concurrency_count=10).launch(
            server_name=self.host,
            server_port=self.port,
        )
|
|