import os

# Install FlashAttention at runtime when running inside a Hugging Face Space.
if 'SPACES_APP' in os.environ:
    os.system("pip install flash-attn==2.7.3 --no-build-isolation")

import sys
import torch
import diffusers
import transformers
import argparse
import peft
import copy
import cv2
import spaces
import gradio as gr
import numpy as np
from peft import LoraConfig
from omegaconf import OmegaConf
from safetensors.torch import safe_open
from PIL import Image, ImageDraw, ImageFilter
from huggingface_hub import hf_hub_download
from transformers import pipeline
from models import HunyuanVideoTransformer3DModel
from pipelines import HunyuanVideoImageToVideoPipeline
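# `models` and `pipelines` are the DRA-Ctrl repo's local packages; they provide
# the customized HunyuanVideo transformer and image-to-video pipeline used below
# (not the stock diffusers classes).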
| header = """ | |
| # DRA-Ctrl Gradio App | |
| <div style="text-align: center; display: flex; justify-content: left; gap: 5px;"> | |
| <a href="https://arxiv.org/pdf/2505.23325"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a> | |
| <a href="https://arxiv.org/abs/2505.23325"><img src="https://img.shields.io/badge/ariXv-Page-A42C25.svg" alt="arXiv"></a> | |
| <a href="https://huggingface.co/Kunbyte/DRA-Ctrl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a> | |
| <a href="https://huggingface.co/spaces/Kunbyte/DRA-Ctrl"><img src="https://img.shields.io/badge/🤗-Space-ffbd45.svg" alt="HuggingFace"></a> | |
| <a href="https://github.com/Kunbyte-AI/DRA-Ctrl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a> | |
| <a href="https://dra-ctrl-2025.github.io/DRA-Ctrl/"><img src="https://img.shields.io/badge/Project-Page-blue" alt="Project"></a> | |
| </div> | |
| """ | |
| notice = """ | |
| For easier testing, in spatially-aligned image generation tasks, when passing the condition image to `gradio_app`, | |
| there's no need to manually input edge maps, depth maps, or other condition images - only the original image is required. | |
| The corresponding condition images will be automatically extracted. | |
| """ | |

# ZeroGPU: request a GPU for the duration of each call. The Space runs on Zero,
# and the `spaces` import above exists for this decorator.
@spaces.GPU
def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task):
    # Generation resolution and a local directory for intermediate condition images.
    # (In the original script these came via `args`; fixed values are used here.)
    img_size = 512
    save_dir = './output'
    os.makedirs(save_dir, exist_ok=True)

    # init models
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V',
        subfolder="transformer",
        inference_subject_driven=task in ['subject_driven'])
    scheduler = diffusers.FlowMatchEulerDiscreteScheduler()
    vae = diffusers.AutoencoderKLHunyuanVideo.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="vae")
    text_encoder = transformers.LlavaForConditionalGeneration.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="text_encoder")
    text_encoder_2 = transformers.CLIPTextModel.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="text_encoder_2")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer")
    tokenizer_2 = transformers.CLIPTokenizer.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer_2")
    image_processor = transformers.CLIPImageProcessor.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="image_processor")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    # Freeze everything and move it to the target device in bf16; tiled/sliced
    # VAE decoding keeps peak GPU memory low.
    transformer.requires_grad_(False)
    vae.requires_grad_(False).to(device, dtype=weight_dtype)
    text_encoder.requires_grad_(False).to(device, dtype=weight_dtype)
    text_encoder_2.requires_grad_(False).to(device, dtype=weight_dtype)
    transformer.to(device, dtype=weight_dtype)
    vae.enable_tiling()
    vae.enable_slicing()
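
    # A task-specific LoRA adapter (downloaded from Kunbyte/DRA-Ctrl below) is
    # injected into the attention, feed-forward, and modulation layers of the
    # frozen transformer; only these low-rank weights differ between tasks.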
    # insert LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        init_lora_weights="gaussian",
        target_modules=[
            'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
            'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
            'ff.net.0.proj', 'ff.net.2',
            'ff_context.net.0.proj', 'ff_context.net.2',
            'norm1_context.linear', 'norm1.linear',
            'norm.linear', 'proj_mlp', 'proj_out',
        ]
    )
    transformer.add_adapter(lora_config)
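
    # The hacked forward below routes tokens selectively through the adapter:
    # condition-image tokens and text/encoder tokens use the LoRA-adapted weights,
    # while the target (generated) image tokens fall back to the frozen base weights.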
    # hack LoRA forward
    def create_hacked_forward(module):
        lora_forward = module.forward
        non_lora_forward = module.base_layer.forward
        img_sequence_length = int((img_size / 8 / 2) ** 2)
        encoder_sequence_length = 144 + 252  # encoder sequence: 144 img 252 txt
        num_imgs = 4
        num_generated_imgs = 3
        num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1

        def hacked_lora_forward(self, x, *args, **kwargs):
            if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
                # image tokens only: condition part -> LoRA, generated part -> base weights
                return torch.cat((
                    lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
                    non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
                ), dim=1)
            elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
                # text/encoder tokens -> LoRA
                return lora_forward(x, *args, **kwargs)
            elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
                # full sequence: split the image and encoder segments accordingly
                return torch.cat((
                    lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
                    non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
                    lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
                ), dim=1)
            elif x.shape[1] == 3072:
                # 2-D inputs with the model's hidden size (e.g. modulation embeddings) use the base weights
                return non_lora_forward(x, *args, **kwargs)
            else:
                raise ValueError(
                    f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
                )

        return hacked_lora_forward.__get__(module, type(module))

    for n, m in transformer.named_modules():
        if isinstance(m, peft.tuners.lora.layer.Linear):
            m.forward = create_hacked_forward(m)

    # load LoRA weights
    model_root = hf_hub_download(
        repo_id="Kunbyte/DRA-Ctrl",
        filename=f"{task}.safetensors",
        resume_download=True)
    try:
        with safe_open(model_root, framework="pt") as f:
            lora_weights = {}
            for k in f.keys():
                param = f.get_tensor(k)
                if k.endswith(".weight"):
                    # PEFT registers adapter weights under the "default" adapter name.
                    k = k.replace('.weight', '.default.weight')
                lora_weights[k] = param
            transformer.load_state_dict(lora_weights, strict=False)
    except Exception as e:
        raise ValueError(f'Failed to load LoRA weights from {model_root}: {e}') from e

    transformer.requires_grad_(False)

    pipe = HunyuanVideoImageToVideoPipeline(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=copy.deepcopy(scheduler),
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        image_processor=image_processor,
    )
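
    # Although this is a video pipeline, it is used as an image generator here:
    # a short 5-frame clip is sampled with the condition image as reference, and
    # a single frame of the result is returned as the output image.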

    # start generation
    c_txt = None if condition_image_prompt == "" else condition_image_prompt
    c_img = condition_image.resize((512, 512))
    t_txt = target_prompt

    # For spatially-aligned tasks, derive the actual condition image (edge map,
    # grayscale, blurred, depth, masked, or low-resolution version) from the upload.
    if task not in ['subject_driven', 'style_transfer']:
        if task == "canny":
            def get_canny_edge(img):
                img_np = np.array(img)
                img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)  # PIL images are RGB
                edges = cv2.Canny(img_gray, 100, 200)
                edges_tmp = Image.fromarray(edges).convert("RGB")
                edges_tmp.save(os.path.join(save_dir, "edges.png"))
                edges[edges == 0] = 128
                return Image.fromarray(edges).convert("RGB")
            c_img = get_canny_edge(c_img)
        elif task == "coloring":
            c_img = (
                c_img.resize((img_size, img_size))
                .convert("L")
                .convert("RGB")
            )
        elif task == "deblurring":
            blur_radius = 10
            c_img = (
                c_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((img_size, img_size))
                .convert("RGB")
            )
        elif task == "depth":
            def get_depth_map(img):
                depth_pipe = pipeline(
                    task="depth-estimation",
                    model="LiheYoung/depth-anything-small-hf",
                    device="cpu",
                )
                return depth_pipe(img)["depth"].convert("RGB").resize((img_size, img_size))
            c_img = get_depth_map(c_img)
            c_img.save(os.path.join(save_dir, "depth.png"))
            # Remap depth values from [0, 255] to [128, 255].
            k = (255 - 128) / 255
            b = 128
            c_img = c_img.point(lambda x: k * x + b)
        elif task == "depth_pred":
            c_img = c_img
        elif task == "fill":
            c_img = c_img.resize((img_size, img_size)).convert("RGB")
            # The original script took the fill rectangle and the in/out-painting mode
            # from command-line args; a centered square and inpainting are assumed here.
            x1, y1 = img_size // 4, img_size // 4
            x2, y2 = img_size * 3 // 4, img_size * 3 // 4
            inpainting = True
            mask = Image.new("L", (img_size, img_size), 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle((x1, y1, x2, y2), fill=255)
            if inpainting:
                mask = Image.eval(mask, lambda a: 255 - a)
            c_img = Image.composite(
                c_img,
                Image.new("RGB", (img_size, img_size), (255, 255, 255)),
                mask
            )
            c_img.save(os.path.join(save_dir, "mask.png"))
            c_img = Image.composite(
                c_img,
                Image.new("RGB", (img_size, img_size), (128, 128, 128)),
                mask
            )
        elif task == "sr":
            c_img = c_img.resize((int(img_size / 4), int(img_size / 4))).convert("RGB")
            c_img.save(os.path.join(save_dir, "low_resolution.png"))
            c_img = c_img.resize((img_size, img_size))
            c_img.save(os.path.join(save_dir, "low_to_high.png"))

    gen_img = pipe(
        image=c_img,
        prompt=[t_txt.strip()],
        prompt_condition=[c_txt.strip()] if c_txt is not None else None,
        prompt_2=[t_txt],
        height=512,
        width=512,
        num_frames=5,
        num_inference_steps=50,
        guidance_scale=6.0,
        num_videos_per_prompt=1,
        generator=torch.Generator(device=pipe.transformer.device).manual_seed(0),
        output_type='pt',
        image_embed_interleave=4,
        frame_gap=48,
        mixup=True,
        mixup_num_imgs=2,
    ).frames

    # Keep the first generated frame and convert it from a [0, 1] tensor
    # to an 8-bit PIL image.
    gen_img = gen_img[:, 0:1, :, :, :]
    gen_img = gen_img.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
    gen_img = np.transpose(gen_img, (1, 2, 0))
    gen_img = (gen_img * 255).astype(np.uint8)
    gen_img = Image.fromarray(gen_img)

    return gen_img
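
# Example of calling the generation function directly, without the UI
# (the image path and prompts below are illustrative):
#
#   from PIL import Image
#   img = Image.open("example.png").convert("RGB")
#   out = process_image_and_text(img, "a cat wearing a red scarf", "a cat", "subject_driven")
#   out.save("result.png")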

def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                condition_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                task = gr.Radio(
                    [
                        ("Subject-driven Image Generation", "subject_driven"),
                        ("Canny-to-Image", "canny"),
                        ("Colorization", "coloring"),
                        ("Deblurring", "deblurring"),
                        ("Depth-to-Image", "depth"),
                        ("Depth Prediction", "depth_pred"),
                        ("In/Out-Painting", "fill"),
                        ("Super-Resolution", "sr"),
                        ("Style Transfer", "style_transfer"),
                    ],
                    label="Task Selection",
                    value="subject_driven",
                    interactive=True,
                    elem_id="task_selection",
                )
                gr.Markdown(notice, elem_id="notice")
                target_prompt = gr.Textbox(lines=2, label="Target Prompt", elem_id="target_prompt")
                condition_image_prompt = gr.Textbox(lines=2, label="Condition Image Prompt", elem_id="condition_image_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        submit_btn.click(
            fn=process_image_and_text,
            inputs=[condition_image, target_prompt, condition_image_prompt, task],
            outputs=output_image,
        )

    return app


if __name__ == "__main__":
    create_app().launch(debug=True, ssr_mode=False)
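
# Launch locally with `python app.py`; Gradio prints a local URL to open in the
# browser (ssr_mode=False disables Gradio's server-side rendering of the UI).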