Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import datetime | |
| import os | |
| import pathlib | |
| import shlex | |
| import shutil | |
| import subprocess | |
| import gradio as gr | |
| import slugify | |
| import torch | |
| from PIL import Image | |
| from huggingface_hub import HfApi | |
| from app_upload import ModelUploader | |
| from utils import save_model_card | |
# Endpoint that ``Trainer.join_library_org`` POSTs to with the user's token.
URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/realfill-library/share/WctmaLvDHWxnuWoJxagTrzVXbGwxoqoJoG'


class Trainer:
    """Drive a RealFill fine-tuning run.

    Stages the uploaded images on disk, launches ``train_realfill.py`` as a
    subprocess, writes a model card and optionally uploads the output
    directory through ``ModelUploader``.
    """

    def __init__(self, hf_token: str | None = None):
        """Store *hf_token* and build the Hub clients that use it."""
        self.hf_token = hf_token
        self.api = HfApi(token=hf_token)
        self.model_uploader = ModelUploader(hf_token)

    @staticmethod
    def _load_converted(path, mode: str):
        """Open the image at *path* and return a copy converted to *mode*."""
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file handle can be released as soon as the conversion is done.
        with Image.open(path) as image:
            return image.convert(mode)

    @staticmethod
    def _save_jpeg(image, *paths: pathlib.Path) -> None:
        """Write *image* as a quality-100 JPEG to every path in *paths*."""
        for path in paths:
            image.save(path, format='JPEG', quality=100)

    def prepare_dataset(self, reference_images: list,
                        target_image: list, target_mask: list,
                        train_data_dir: pathlib.Path,
                        output_dir: pathlib.Path) -> None:
        """Recreate *train_data_dir* and fill it with the training images.

        Args:
            reference_images: Uploaded file objects; each must expose the
                path of the file on disk as ``.name``.
            target_image: Sequence of uploaded file objects; only the first
                entry is used.
            target_mask: Sequence of uploaded file objects; only the first
                entry is used.
            train_data_dir: Directory that receives ``ref/NNN.jpg``,
                ``target/target.jpg`` and ``target/mask.jpg``. Any previous
                content is removed first.
            output_dir: Existing directory that receives a copy of
                ``target.jpg`` and ``mask.jpg``.
        """
        shutil.rmtree(train_data_dir, ignore_errors=True)
        ref_dir = train_data_dir / 'ref'
        target_dir = train_data_dir / 'target'
        train_data_dir.mkdir(parents=True)
        ref_dir.mkdir()
        target_dir.mkdir()

        for i, temp_path in enumerate(reference_images):
            reference = self._load_converted(temp_path.name, 'RGB')
            self._save_jpeg(reference, ref_dir / f'{i:03d}.jpg')

        target = self._load_converted(target_image[0].name, 'RGB')
        self._save_jpeg(target,
                        target_dir / 'target.jpg',
                        output_dir / 'target.jpg')

        # The mask is stored as a single-channel (mode "L") image.
        mask = self._load_converted(target_mask[0].name, 'L')
        self._save_jpeg(mask,
                        target_dir / 'mask.jpg',
                        output_dir / 'mask.jpg')

    def join_library_org(self) -> None:
        """POST the user's token to ``URL_TO_JOIN_LIBRARY_ORG`` via ``curl``.

        Best effort: the exit status of ``curl`` is not checked.
        """
        # An argument list (instead of parsing a formatted string) keeps the
        # token a single argument whatever characters it contains.
        # NOTE(review): the token is still visible in the process arguments
        # while curl runs; consider an in-process HTTP client.
        subprocess.run([
            'curl',
            '-X',
            'POST',
            '-H',
            f'Authorization: Bearer {self.hf_token}',
            '-H',
            'Content-Type: application/json',
            URL_TO_JOIN_LIBRARY_ORG,
        ])

    @staticmethod
    def _build_train_command(
        *,
        base_model: str,
        train_data_dir,
        output_dir,
        resolution: int,
        n_steps: int,
        unet_learning_rate: float,
        text_encoder_learning_rate: float,
        lora_rank: int,
        lora_dropout: float,
        lora_alpha: int,
        gradient_accumulation: int,
        seed: int,
        checkpointing_steps: int,
        validation_steps: int,
        fp16: bool,
        use_8bit_adam: bool,
        use_wandb: bool,
    ) -> list[str]:
        """Return the argv list that launches ``train_realfill.py``.

        Every value is embedded in exactly one argument, so paths containing
        spaces stay intact and a user-supplied value cannot add extra flags.
        """
        argv = [
            'python',
            'train_realfill.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--train_data_dir={train_data_dir}',
            f'--output_dir={output_dir}',
            f'--resolution={resolution}',
            '--train_batch_size=16',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            '--gradient_checkpointing',
            f'--unet_learning_rate={unet_learning_rate}',
            f'--text_encoder_learning_rate={text_encoder_learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=100',
            '--set_grads_to_none',
            f'--max_train_steps={n_steps}',
            f'--checkpointing_steps={checkpointing_steps}',
            f'--validation_steps={validation_steps}',
            f'--lora_rank={lora_rank}',
            f'--lora_dropout={lora_dropout}',
            f'--lora_alpha={lora_alpha}',
            f'--seed={seed}',
        ]
        if fp16:
            argv += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            argv.append('--use_8bit_adam')
        if use_wandb:
            argv += ['--report_to', 'wandb']
        return argv

    def _release_gpu(self) -> None:
        """Request ``cpu-basic`` hardware for the Space named by ``SPACE_ID``.

        Does nothing when the ``SPACE_ID`` environment variable is unset.
        """
        space_id = os.getenv('SPACE_ID')
        if space_id:
            self.api.request_space_hardware(repo_id=space_id,
                                            hardware='cpu-basic')

    def run(
        self,
        reference_images: list | None,
        target_image: list | None,
        target_mask: list | None,
        output_model_name: str,
        overwrite_existing_model: bool,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        unet_learning_rate: float,
        text_encoder_learning_rate: float,
        lora_rank: int,
        lora_dropout: float,
        lora_alpha: int,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_steps: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
        remove_gpu_after_training: bool,
    ) -> str:
        """Train a model on the uploaded images and optionally upload it.

        Args:
            reference_images: Uploaded reference image files.
            target_image: Uploaded target image (first entry is used).
            target_mask: Uploaded target mask (first entry is used).
            output_model_name: Name of the experiment; slugified before use.
                A timestamped ``realfill-...`` name is generated when it is
                empty or slugifies to nothing.
            overwrite_existing_model: Remove an existing output directory of
                the same name. ``upload_to_hub`` has the same effect.
            base_model: Passed as ``--pretrained_model_name_or_path``.
            resolution_s: Training resolution as a string; parsed with
                ``int``.
            n_steps: Passed as ``--max_train_steps``.
            unet_learning_rate, text_encoder_learning_rate, lora_rank,
            lora_dropout, lora_alpha, seed, checkpointing_steps,
            validation_steps: Forwarded to the flag of the same name.
            gradient_accumulation: Passed as
                ``--gradient_accumulation_steps``.
            fp16: Add ``--mixed_precision fp16``.
            use_8bit_adam: Add ``--use_8bit_adam``.
            use_wandb: Add ``--report_to wandb``.
            upload_to_hub: Join the library organization before training and
                upload the output directory afterwards.
            use_private_repo, delete_existing_repo, upload_to: Forwarded to
                ``ModelUploader.upload_model``.
            remove_gpu_after_training: Switch the Space to ``cpu-basic``
                hardware once training has finished (or failed).

        Returns:
            A status message, including the uploader's message when the
            model was uploaded.

        Raises:
            gr.Error: If CUDA is unavailable, a required upload is missing,
                the output directory already exists and may not be
                overwritten, or the training process exits with a non-zero
                status.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if not reference_images:
            raise gr.Error('You need to upload reference images.')
        if not target_image:
            raise gr.Error('You need to upload a target image.')
        if not target_mask:
            raise gr.Error('You need to upload a target mask.')

        resolution = int(resolution_s)

        # Slugify first and only then fall back to a generated name: a name
        # made solely of characters the slugifier drops would otherwise
        # resolve ``output_dir`` to the shared ``experiments`` directory,
        # which the overwrite branch below would delete wholesale.
        model_name = (slugify.slugify(output_model_name)
                      if output_model_name else '')
        if not model_name:
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            model_name = slugify.slugify(f'realfill-{timestamp}')
        output_model_name = model_name

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        elif output_dir.exists():
            # Report the conflict in the UI instead of letting ``mkdir``
            # fail with a bare FileExistsError.
            raise gr.Error(
                f'A model named "{output_model_name}" already exists; '
                'choose another name or allow overwriting it.')
        output_dir.mkdir(parents=True)

        train_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(reference_images, target_image, target_mask,
                             train_data_dir, output_dir)

        if upload_to_hub:
            self.join_library_org()

        command = self._build_train_command(
            base_model=base_model,
            train_data_dir=train_data_dir,
            output_dir=output_dir,
            resolution=resolution,
            n_steps=n_steps,
            unet_learning_rate=unet_learning_rate,
            text_encoder_learning_rate=text_encoder_learning_rate,
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            lora_alpha=lora_alpha,
            gradient_accumulation=gradient_accumulation,
            seed=seed,
            checkpointing_steps=checkpointing_steps,
            validation_steps=validation_steps,
            fp16=fp16,
            use_8bit_adam=use_8bit_adam,
            use_wandb=use_wandb,
        )

        # Keep a copy-pastable record of the exact command next to its output.
        (output_dir / 'train.sh').write_text(shlex.join(command))

        completed = subprocess.run(command)
        if completed.returncode != 0:
            # Training is over either way, so still honour the request to
            # stop using GPU hardware before reporting the failure.
            if remove_gpu_after_training:
                self._release_gpu()
            raise gr.Error(
                f'Training failed with exit code {completed.returncode}.')

        save_model_card(save_dir=output_dir,
                        base_model=base_model,
                        target_image=output_dir / 'target.jpg',
                        target_mask=output_dir / 'mask.jpg')

        message = 'Training completed!'
        print(message)

        if upload_to_hub:
            upload_message = self.model_uploader.upload_model(
                folder_path=output_dir.as_posix(),
                repo_name=output_model_name,
                upload_to=upload_to,
                private=use_private_repo,
                delete_existing_repo=delete_existing_repo)
            print(upload_message)
            message = message + '\n' + upload_message

        if remove_gpu_after_training:
            self._release_gpu()

        return message