Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from inference.utils.prompt_utils import encode_prompt | |
| from inference.utils.image_utils import preprocess_controlnet_image | |
| from inference.lora.lora_layers import init_lora_attn | |
def run_controlnet_inference(
    prompt: str,
    control_image: Image.Image,
    noise_scheduler,
    tokenizer,
    text_encoder,
    vae,
    unet,
    controlnet,
    num_inference_steps=50,
    guidance_scale=7.5,
    device=None,
    weight_dtype=torch.float32
):
    """Generate one image from a prompt and a ControlNet conditioning image.

    Runs a classifier-free-guidance denoising loop: at every scheduler
    timestep the ControlNet produces residuals from the conditioning image,
    the UNet predicts noise using those residuals, and the unconditional and
    prompt-conditioned predictions are blended with ``guidance_scale``. The
    final latents are decoded by the VAE and converted to an 8-bit image.

    Args:
        prompt: Text prompt for the conditional branch. The unconditional
            branch always uses the empty string.
        control_image: Conditioning image handed to
            ``preprocess_controlnet_image``.
        noise_scheduler: Scheduler exposing ``set_timesteps``, ``timesteps``,
            ``init_noise_sigma``, ``scale_model_input`` and ``step``.
        tokenizer: Passed through to ``encode_prompt``.
        text_encoder: Passed through to ``encode_prompt``.
        vae: Autoencoder whose ``decode(latents).sample`` yields the image.
        unet: Denoising network; must accept ControlNet residuals.
        controlnet: Called with ``return_dict=False``; must return
            ``(down_block_res_samples, mid_block_res_sample)``.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Weight applied to (conditional - unconditional).
        device: Target device; defaults to CUDA when available, else CPU.
        weight_dtype: dtype used for the latents, the conditioning image and
            the ControlNet residuals fed to the UNet.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample.

    Note:
        The latent grid is fixed at 64x64, so the preprocessed conditioning
        image must match what the ControlNet expects for that latent size
        (presumably 512x512 for an 8x VAE -- confirm against
        ``preprocess_controlnet_image``).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Prefer the model config over the deprecated direct attribute, but keep
    # the old attribute as a fallback for objects without a usable config.
    in_channels = getattr(getattr(unet, "config", None), "in_channels", None)
    if in_channels is None:
        in_channels = unet.in_channels

    # Use the VAE's own latent scaling factor when it declares one; 0.18215
    # is the value this function previously hard-coded.
    vae_scaling_factor = getattr(
        getattr(vae, "config", None), "scaling_factor", None
    )
    if vae_scaling_factor is None:
        vae_scaling_factor = 0.18215

    # Inference only: without no_grad the loop keeps an autograd graph for
    # every step, and .numpy() below fails on tensors that require grad.
    with torch.no_grad():
        cond_embeds = encode_prompt(prompt, tokenizer, text_encoder, device)
        uncond_embeds = encode_prompt("", tokenizer, text_encoder, device)
        # Order matters: chunk(2) below assumes [unconditional, conditional].
        prompt_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0)

        controlnet_image = preprocess_controlnet_image(
            control_image, device=device, dtype=weight_dtype
        )
        # Duplicate the conditioning so it lines up with the doubled latents.
        controlnet_image = torch.cat([controlnet_image, controlnet_image], dim=0)

        # Configure the schedule BEFORE reading init_noise_sigma: for some
        # schedulers that value depends on the inference timesteps.
        noise_scheduler.set_timesteps(num_inference_steps)

        batch_size = 1
        latent_shape = (batch_size, in_channels, 64, 64)
        latents = torch.randn(latent_shape, device=device, dtype=weight_dtype)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in noise_scheduler.timesteps:
            # One forward pass serves both guidance branches.
            latent_input = torch.cat([latents] * 2)
            latent_input = noise_scheduler.scale_model_input(latent_input, t)

            down_block_res_samples, mid_block_res_sample = controlnet(
                latent_input, t,
                encoder_hidden_states=prompt_embeds,
                controlnet_cond=controlnet_image,
                return_dict=False
            )

            noise_pred = unet(
                latent_input, t,
                encoder_hidden_states=prompt_embeds,
                down_block_additional_residuals=[
                    res.to(dtype=weight_dtype) for res in down_block_res_samples
                ],
                mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
                return_dict=False,
            )[0]

            # Classifier-free guidance: push the prediction away from the
            # unconditional branch toward the prompt-conditioned one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the latent scaling before decoding.
        latents = 1 / vae_scaling_factor * latents
        image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1]; cast to float32 because NumPy cannot
    # represent bfloat16 tensors.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = (image * 255).round().astype("uint8")
    return Image.fromarray(image)