| from convert_images import convert_heic_to_jpeg_and_remove, convert_png_to_jpeg
|
| from generate_class_images import generate_class_images
|
| from model_loader import load_models
|
| from train import parse_args, training_function
|
| from inference import generate_images
|
|
|
def main():
    """Run the end-to-end fine-tuning workflow driven by command-line arguments.

    Steps, in order:
      1. Parse arguments with ``parse_args`` and load the text encoder, VAE,
         UNet and tokenizer from ``args.pretrained_model_name_or_path``.
      2. Normalise the instance images in ``args.instance_data_dir``. HEIC files
         are converted to JPEG (and, per the helper's name, the originals are
         removed). PNG files are then converted to JPEG in the same directory.
      3. If ``args.with_prior_preservation`` is set, generate class images for
         ``args.class_prompt`` into ``args.class_data_dir``.
      4. Train via ``training_function``.
      5. Save a pipeline built around the trained components to ``./output/``.
    """
    # Imported locally: the module's top-level imports do not bring in
    # diffusers, yet this function needs StableDiffusionPipeline both for
    # class-image generation and for saving the result.
    from diffusers import StableDiffusionPipeline

    args = parse_args()

    text_encoder, vae, unet, tokenizer = load_models(args.pretrained_model_name_or_path)

    # Bring every instance image to JPEG before training reads the directory.
    convert_heic_to_jpeg_and_remove(args.instance_data_dir)
    convert_png_to_jpeg(args.instance_data_dir, args.instance_data_dir)

    if args.with_prior_preservation:
        # Honour a num_class_images argument if parse_args defines one;
        # otherwise keep the previous fixed count of 100.
        num_class_images = getattr(args, "num_class_images", 100)
        class_pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path
        )
        generate_class_images(
            pipeline=class_pipeline,
            class_prompt=args.class_prompt,
            num_class_images=num_class_images,
            class_images_dir=args.class_data_dir,
        )
        # Drop this second copy of the model weights so it can be
        # garbage-collected before training starts.
        del class_pipeline

    training_function(args, text_encoder, vae, unet, tokenizer)

    # Build the pipeline to save from the components that were handed to
    # training_function, overriding the pretrained ones.
    # NOTE(review): assumes training_function updates text_encoder/vae/unet in
    # place (its return value is not used here) — confirm against train.py.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained("./output/")
|
if __name__ == "__main__":
    # Execute the workflow only when run as a script, never on import.
    main()
|
|
|