# Gradio Space: ablating concepts (styles / objects / memorized images) from
# Stable Diffusion v1.4. Recovered from a Hugging Face Space page listing;
# the page's "Runtime error" banner text has been removed.
| import gradio as gr | |
| import PIL.Image | |
| import shlex | |
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| import os | |
| import torch | |
| from tqdm import tqdm | |
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Letterbox *image* onto a black square of side max(width, height).

    The original is centered on the square canvas; an already-square image
    is returned unchanged (same object, no copy).
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # Same mode as the source, filled black; center the original on it.
    canvas = PIL.Image.new(image.mode, (side, side), (0, 0, 0))
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
def train_submit(
    prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath=None
):
    """Launch concept-ablation fine-tuning of SD v1.4 as a subprocess.

    Args:
        prompt: concept to ablate (the "caption target").
        anchor_prompt: anchor concept; for 'object' it also names the class
            data dir, for 'style' it is overridden with a bundled prompt file.
        concept_type: one of 'style', 'object', 'memorization'.
        reg_lambda: accepted for interface compatibility; not used here.
        iterations: max training steps.
        lr: learning rate (overridden to 5e-7 for memorization).
        openai_key: OpenAI API key, OR a newline-separated list of anchor
            prompts (if multi-line, it is dumped to a file and no key is used).
        save_path: output directory for weights and train.sh.
        mem_impath: for 'memorization' only — gallery item holding the
            memorized image to ablate.

    Returns:
        (gr.update with a status message, path to the first saved *.bin weight).

    Raises:
        gr.Error: no CUDA device, missing required inputs, or no weights saved.
    """
    if not torch.cuda.is_available():
        raise gr.Error('CUDA is not available.')
    torch.cuda.empty_cache()

    original_prompt = prompt
    parameter_group = "cross-attn"
    train_batch_size = 4

    if concept_type == 'style':
        # Bundled painting samples / prompts; no OpenAI key needed.
        class_data_dir = './data/samples_painting/'
        anchor_prompt = './assets/painting.txt'
        openai_key = ''
    elif concept_type == 'object':
        os.makedirs('temp', exist_ok=True)
        class_data_dir = f'./temp/{anchor_prompt}'
        name = save_path.split('/')[-1]
        prompt = f'{anchor_prompt}+{prompt}'
        # Was `assert openai_key is not None` — asserts vanish under -O, so
        # surface a proper UI error instead.
        if openai_key is None:
            raise gr.Error('An OpenAI key (or a newline-separated prompt list) is required.')
        if len(openai_key.split('\n')) > 1:
            # The field holds user-supplied anchor prompts, not a key:
            # write them to a file and pass that file as the class prompt.
            prompts = openai_key.split('\n')
            with open(f'./temp/{name}.txt', 'w') as f:
                for prompt_ in prompts:
                    f.write(prompt_.strip() + '\n')
            openai_key = ''
            anchor_prompt = f'./temp/{name}.txt'
    elif concept_type == 'memorization':
        # Fetch the SSCD copy-detection model used to filter memorized samples.
        # subprocess.run (shell=False) instead of os.system for consistency
        # with the training launch below.
        subprocess.run(shlex.split(
            'wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchscript.pt -P assets/'
        ))
        os.makedirs('temp', exist_ok=True)
        prompt = f'*+{prompt}'
        name = save_path.split('/')[-1]
        # Memorization ablation trains the full weights with a tiny LR.
        train_batch_size = 1
        lr = 5e-7
        parameter_group = "full-weight"
        if openai_key is None:
            raise gr.Error('An OpenAI key (or a newline-separated prompt list) is required.')
        if mem_impath is None:
            raise gr.Error('A memorized image must be provided for memorization ablation.')
        if len(openai_key.split('\n')) > 1:
            prompts = openai_key.split('\n')
            with open(f'./temp/{name}.txt', 'w') as f:
                for prompt_ in prompts:
                    f.write(prompt_.strip() + '\n')
            openai_key = ''
            anchor_prompt = f'./temp/{name}.txt'
        else:
            anchor_prompt = prompt
        print(mem_impath)
        # Re-save the uploaded image as a padded square RGB JPEG.
        image = PIL.Image.open(mem_impath[0][0].name)
        image = pad_image(image)
        image = image.convert('RGB')
        mem_impath = f"./temp/{original_prompt.lower().replace(' ', '')}.jpg"
        image.save(mem_impath, format='JPEG', quality=100)
        class_data_dir = f"./temp/{original_prompt.lower().replace(' ', '')}"

    command = f'''
    accelerate launch concept-ablation-diffusers/train.py \
        --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4"  \
        --output_dir={save_path} \
        --class_data_dir={class_data_dir} \
        --class_prompt="{anchor_prompt}"  \
        --caption_target "{prompt}" \
        --concept_type {concept_type} \
        --resolution=512  \
        --train_batch_size={train_batch_size}  \
        --learning_rate={lr}  \
        --max_train_steps={iterations} \
        --scale_lr --hflip \
        --parameter_group {parameter_group} \
        --openai_key "{openai_key}" \
        --enable_xformers_memory_efficient_attention --num_class_images 500
    '''
    if concept_type == 'style':
        command += ' --noaug'
    if concept_type == 'memorization':
        command += f' --use_8bit_adam --with_prior_preservation --prior_loss_weight=1.0 --mem_impath {mem_impath}'

    # The output dir may not exist yet (train.sh is written BEFORE training
    # runs) — previously this open() could fail with FileNotFoundError.
    os.makedirs(save_path, exist_ok=True)
    with open(f'{save_path}/train.sh', 'w') as f:
        f.write(' '.join(command.split()))

    res = subprocess.run(shlex.split(command))
    result_message = 'Training Completed!' if res.returncode == 0 else 'Training Failed!'

    weight_paths = sorted(Path(save_path).glob('*.bin'))
    print(weight_paths)
    if not weight_paths:
        # Previously a bare IndexError on weight_paths[0].
        raise gr.Error('Training produced no weight file (*.bin) in the output directory.')
    return gr.update(value=result_message), weight_paths[0]
def inference(model_path, prompt, n_steps, generator):
    """Sample *prompt* before and after loading the ablated weights.

    Builds a fresh SD v1.4 custom-diffusion pipeline on CUDA, generates one
    image with the base weights, loads the fine-tuned weights from
    *model_path*, and generates a second image with identical sampling
    settings. Returns (base_image, ablated_image).
    """
    # Local imports: model_pipeline only resolves after the sys.path tweak.
    import sys
    sys.path.append('concept-ablation/diffusers/.')
    from model_pipeline import CustomDiffusionPipeline
    import torch

    pipe = CustomDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ).to("cuda")
    sample = lambda: pipe(
        prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator
    ).images[0]
    before = sample()
    pipe.load_model(model_path)
    after = sample()
    return before, after