Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import spaces | |
| import torch | |
| from diffusers import StableDiffusion3Pipeline | |
| from adv_grpo.diffusers_patch.sd3_pipeline_with_logprob_fast import pipeline_with_logprob_random as pipeline_with_logprob | |
| from adv_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt | |
| from adv_grpo.ema import EMAModuleWrapper | |
| from peft import PeftModel | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| from ml_collections import config_flags | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub import login | |
# Authenticate against the Hugging Face Hub when a token is supplied through
# the environment. A missing HF_TOKEN used to abort the import with a bare
# KeyError; now the explicit login is skipped and later downloads fall back to
# whatever credentials huggingface_hub can find on its own (if any).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(_hf_token)
else:
    print("HF_TOKEN is not set; skipping explicit Hugging Face login.")
| # --------------------------------------------------------- | |
| # GLOBAL VARIABLES | |
| # --------------------------------------------------------- | |
# Shared model state. Everything stays None until init_model() fills it in.
# infer() treats `pipeline is None` as the "model not loaded yet" signal.
pipeline = config = None

# Text-conditioning components taken from the loaded pipeline.
text_encoders = tokenizers = None

# EMA wrapper and the list of transformer parameters it tracks.
ema = transformer_trainable_parameters = None
def load_lora_from_subfolder():
    """Download the DINO LoRA adapter files and return the download root.

    Fetches ``adapter_config.json`` and ``adapter_model.safetensors`` from the
    ``DINO`` subfolder of the ``benzweijia/Adv-GRPO`` model repository into
    ``/tmp/DINO``.

    Returns:
        The download root ``"/tmp/DINO"`` (not the adapter directory itself).

    NOTE(review): init_model() appends ``"DINO"`` to the returned path, which
    presumably mirrors hf_hub_download placing files under
    ``<local_dir>/<subfolder>/`` -- confirm against the huggingface_hub docs.
    """
    download_root = "/tmp/DINO"
    os.makedirs(download_root, exist_ok=True)

    adapter_files = ("adapter_config.json", "adapter_model.safetensors")
    for adapter_file in adapter_files:
        hf_hub_download(
            repo_id="benzweijia/Adv-GRPO",
            repo_type="model",
            subfolder="DINO",
            filename=adapter_file,
            local_dir=download_root,
            force_download=False,
        )

    return download_root
| # -------------- Load Config ------------------------------ | |
def load_config(config_path="config/base.py"):
    """Load a configuration object from a Python config file.

    The file is executed as a standalone module (it is not registered in
    ``sys.modules``) and the result of its ``get_config()`` function is
    returned.

    Args:
        config_path: Path to a Python file defining ``get_config()``. A
            relative path is resolved against the current working directory.
            Defaults to ``"config/base.py"``.

    Returns:
        Whatever ``get_config()`` in the config file returns.

    Raises:
        ImportError: If no Python module loader can be found for
            ``config_path`` (for example, an unsupported file extension).
        FileNotFoundError: If ``config_path`` does not exist (raised while
            executing the module).
        AttributeError: If the file does not define ``get_config``.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("config", config_path)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the path; fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python config module from {config_path!r}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.get_config()
| # -------------- Embedding Function ----------------------- | |
def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
    """Encode ``prompt`` and return its embeddings on ``device``.

    Args:
        prompt: Prompt input forwarded unchanged to ``encode_prompt``
            (infer() passes a list of strings).
        text_encoders: Text encoders forwarded to ``encode_prompt``.
        tokenizers: Tokenizers forwarded to ``encode_prompt``.
        max_sequence_length: Forwarded to ``encode_prompt``.
        device: Target device for the returned tensors.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple, both moved to
        ``device``.
    """
    # Encoding runs with autograd disabled.
    with torch.no_grad():
        encoded = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length)
        sequence_embeds, pooled_embeds = encoded
        return sequence_embeds.to(device), pooled_embeds.to(device)
| # --------------------------------------------------------- | |
| # GPU MODEL INITIALIZATION | |
| # --------------------------------------------------------- | |
def init_model():
    """Load the SD3.5 pipeline with the DINO LoRA adapter and publish it globally.

    Populates the module-level ``pipeline``, ``config``, ``text_encoders``,
    ``tokenizers``, ``ema`` and ``transformer_trainable_parameters``.

    Everything is assembled in local variables and assigned to the globals
    only after the last step succeeded, with ``pipeline`` assigned last.
    infer() uses ``pipeline is None`` to decide whether to initialize, so a
    failure partway through (e.g. while downloading the LoRA) must not leave
    a half-built pipeline behind that later requests would silently reuse.

    Raises:
        Any exception raised by config loading, model download, device
        placement or adapter loading is propagated; in that case the globals
        are left untouched.
    """
    global pipeline, config, text_encoders, tokenizers, ema, transformer_trainable_parameters

    print("🔥 Loading config...")
    cfg = load_config()
    # This demo always serves the DINO LoRA checkpoint. Apply the override
    # before the flag is read below so the freeze logic sees the final value.
    cfg.train.lora_path = "benzweijia/Adv-GRPO/DINO"
    cfg.use_lora = True

    print("🔥 Loading SD3 base model on GPU...")
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium"
    )

    # Freeze non-trainable params; the transformer is frozen too when LoRA is used.
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    pipe.text_encoder_2.requires_grad_(False)
    pipe.text_encoder_3.requires_grad_(False)
    pipe.transformer.requires_grad_(not cfg.use_lora)

    encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
    toks = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]

    pipe.safety_checker = None
    pipe.set_progress_bar_config(disable=True)

    # Move every sub-module to the GPU.
    pipe.vae.to("cuda")
    pipe.text_encoder.to("cuda")
    pipe.text_encoder_2.to("cuda")
    pipe.text_encoder_3.to("cuda")
    pipe.transformer.to("cuda")

    lora_dir = load_lora_from_subfolder()
    if cfg.use_lora and cfg.train.lora_path:
        print("🔥 Loading LoRA from:", cfg.train.lora_path)
        # The adapter files live in the "DINO" directory below the download root.
        pipe.transformer = PeftModel.from_pretrained(
            pipe.transformer,
            os.path.join(lora_dir, "DINO"),
        )
        # NOTE(review): presumably needed so the adapter weights show up as
        # requires_grad for the filter below -- confirm against the PEFT docs.
        pipe.transformer.set_adapter("default")

    trainable = [p for p in pipe.transformer.parameters() if p.requires_grad]

    # Setup EMA over the trainable transformer parameters.
    ema_wrapper = EMAModuleWrapper(
        trainable,
        decay=0.9,
        update_step_interval=8,
        device="cuda",
    )

    # Publish only now that every step succeeded; `pipeline` goes last because
    # it is the "initialized" sentinel checked by infer().
    config = cfg
    text_encoders = encoders
    tokenizers = toks
    transformer_trainable_parameters = trainable
    ema = ema_wrapper
    pipeline = pipe
    print("✅ Model initialized and ready.")
| # --------------------------------------------------------- | |
| # INFERENCE FUNCTION | |
| # --------------------------------------------------------- | |
def infer(prompt):
    """Generate one image for ``prompt`` and return it as a 512x512 PIL image.

    Initializes the model on first use (when the global ``pipeline`` is still
    ``None``), encodes the prompt and an empty negative prompt, samples with a
    generator seeded to 0, and converts the first returned image to PIL.

    Args:
        prompt: Prompt text from the Gradio textbox.

    Returns:
        A ``PIL.Image.Image`` resized to 512x512.

    NOTE(review): ``spaces`` is imported at the top of the file but this
    function is not wrapped with ``spaces.GPU``; if the app is meant to run on
    ZeroGPU hardware, confirm whether the decorator is required.
    """
    global pipeline, config

    print("start infer")
    print(pipeline)
    if pipeline is None:
        init_model()
    print(pipeline)

    def _embed(texts):
        # Shared encoder settings for the positive and the negative prompt.
        return compute_text_embeddings(
            texts, text_encoders, tokenizers,
            max_sequence_length=128,
            device="cuda"
        )

    # Positive conditioning, then the empty-string negative conditioning.
    cond_embeds, cond_pooled = _embed([prompt])
    uncond_embeds, uncond_pooled = _embed([""])
    uncond_embeds = uncond_embeds.repeat(1, 1, 1)
    uncond_pooled = uncond_pooled.repeat(1, 1)

    # Fixed generation seed.
    seeded_generator = torch.Generator()
    seeded_generator.manual_seed(0)

    with torch.no_grad():
        sampling_kwargs = dict(
            prompt_embeds=cond_embeds,
            pooled_prompt_embeds=cond_pooled,
            negative_prompt_embeds=uncond_embeds,
            negative_pooled_prompt_embeds=uncond_pooled,
            num_inference_steps=config.sample.eval_num_steps,
            guidance_scale=config.sample.guidance_scale,
            output_type="pt",
            height=config.resolution,
            width=config.resolution,
            noise_level=0,
            mini_num_image_per_prompt=1,
            process_index=0,
            sample_num_steps=config.sample.num_steps,
            random_timestep=0,
            generator=seeded_generator,
        )
        images, _, _, _ = pipeline_with_logprob(pipeline, **sampling_kwargs)

    # Convert the first image to PIL: move the leading axis last, scale by 255
    # and cast to uint8 (assumes channel-first floats in [0, 1] -- TODO confirm).
    first_image = images[0].permute(1, 2, 0).cpu().numpy()
    pixels = (first_image * 255).astype(np.uint8)

    # Fixed 512x512 for output.
    return Image.fromarray(pixels).resize((512, 512))
| # --------------------------------------------------------- | |
| # GRADIO UI | |
| # --------------------------------------------------------- | |
| # init_model() | |
# Single-prompt demo UI: one text box in, one generated image out.
prompt_box = gr.Textbox(lines=2, label="Prompt")
image_output = gr.Image(type="pil")

demo = gr.Interface(
    fn=infer,
    inputs=prompt_box,
    outputs=image_output,
    title="Adv-GRPO(DINO)",
    description="Enter a prompt and generate image using Adv-GRPO",
)

# Start serving as soon as the script is executed.
demo.launch()