| import torch |
| import torchvision.transforms as T |
| from controlnet_aux import HEDdetector |
|
|
| from diffusers.utils import load_image |
| from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel |
| from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline |
|
|
|
|
| controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" |
|
|
| weight_dtype = torch.float16 |
| image_size = 1024 |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| torch.manual_seed(0) |
|
|
| |
| controlnet = PixArtControlNetAdapterModel.from_pretrained( |
| controlnet_repo_id, |
| torch_dtype=weight_dtype, |
| use_safetensors=True, |
| ).to(device) |
|
|
| pipe = PixArtAlphaControlnetPipeline.from_pretrained( |
| "PixArt-alpha/PixArt-XL-2-1024-MS", |
| controlnet=controlnet, |
| torch_dtype=weight_dtype, |
| use_safetensors=True, |
| ).to(device) |
|
|
| images_path = "images" |
| control_image_file = "0_7.jpg" |
|
|
| |
| |
| |
| |
| |
| |
| prompt = "battleship in space, galaxy in background" |
|
|
| control_image_name = control_image_file.split(".")[0] |
|
|
| control_image = load_image(f"{images_path}/{control_image_file}") |
| print(control_image.size) |
| height, width = control_image.size |
|
|
| hed = HEDdetector.from_pretrained("lllyasviel/Annotators") |
|
|
| condition_transform = T.Compose( |
| [ |
| T.Lambda(lambda img: img.convert("RGB")), |
| T.CenterCrop([image_size, image_size]), |
| ] |
| ) |
|
|
| control_image = condition_transform(control_image) |
| hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size) |
|
|
| hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg") |
|
|
| |
| with torch.no_grad(): |
| out = pipe( |
| prompt=prompt, |
| image=hed_edge, |
| num_inference_steps=14, |
| guidance_scale=4.5, |
| height=image_size, |
| width=image_size, |
| ) |
|
|
| out.images[0].save(f"{images_path}//{control_image_name}_output.jpg") |
|
|