Adv-GRPO_DINO / app.py
benzweijia's picture
Update app.py
7fafe48 verified
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusion3Pipeline
from adv_grpo.diffusers_patch.sd3_pipeline_with_logprob_fast import pipeline_with_logprob_random as pipeline_with_logprob
from adv_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt
from adv_grpo.ema import EMAModuleWrapper
from peft import PeftModel
from PIL import Image
import numpy as np
import os
from ml_collections import config_flags
from huggingface_hub import hf_hub_download
from huggingface_hub import login
login(os.environ["HF_TOKEN"])
# ---------------------------------------------------------
# GLOBAL VARIABLES
# ---------------------------------------------------------
pipeline = None
config = None
text_encoders = None
tokenizers = None
ema = None
transformer_trainable_parameters = None
def load_lora_from_subfolder():
repo_id = "benzweijia/Adv-GRPO"
subfolder = "DINO"
local_dir = "/tmp/DINO"
os.makedirs(local_dir, exist_ok=True)
for filename in ["adapter_config.json", "adapter_model.safetensors"]:
hf_hub_download(
repo_id=repo_id,
repo_type="model",
subfolder=subfolder,
filename=filename,
local_dir=local_dir,
force_download=False
)
# import pdb; pdb.set_trace()
return local_dir
# -------------- Load Config ------------------------------
def load_config():
"""
"""
import importlib.util
config_path = "config/base.py"
spec = importlib.util.spec_from_file_location("config", config_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.get_config()
# -------------- Embedding Function -----------------------
def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, max_sequence_length
)
prompt_embeds = prompt_embeds.to(device)
pooled_prompt_embeds = pooled_prompt_embeds.to(device)
return prompt_embeds, pooled_prompt_embeds
# ---------------------------------------------------------
# GPU MODEL INITIALIZATION
# ---------------------------------------------------------
@spaces.GPU
def init_model():
global pipeline, config, text_encoders, tokenizers, ema, transformer_trainable_parameters
print("πŸ”₯ Loading config...")
config = load_config()
print("πŸ”₯ Loading SD3 base model on GPU...")
# import pdb; pdb.set_trace()
pipeline = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-medium"
)
# freeze non-trainable params
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.text_encoder_2.requires_grad_(False)
pipeline.text_encoder_3.requires_grad_(False)
pipeline.transformer.requires_grad_(not config.use_lora)
text_encoders = [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.text_encoder_3]
tokenizers = [pipeline.tokenizer, pipeline.tokenizer_2, pipeline.tokenizer_3]
pipeline.safety_checker = None
pipeline.set_progress_bar_config(disable=True)
# move to GPU
pipeline.vae.to("cuda")
pipeline.text_encoder.to("cuda")
pipeline.text_encoder_2.to("cuda")
pipeline.text_encoder_3.to("cuda")
pipeline.transformer.to("cuda")
config.train.lora_path = "benzweijia/Adv-GRPO/DINO"
config.use_lora = True
lora_dir = load_lora_from_subfolder()
if config.use_lora and config.train.lora_path:
print("πŸ”₯ Loading LoRA from:", config.train.lora_path)
pipeline.transformer = PeftModel.from_pretrained(
pipeline.transformer,
os.path.join(lora_dir,"DINO")
)
pipeline.transformer.set_adapter("default")
transformer_trainable_parameters = list(
filter(lambda p: p.requires_grad, pipeline.transformer.parameters())
)
# Setup EMA
ema = EMAModuleWrapper(
transformer_trainable_parameters,
decay=0.9,
update_step_interval=8,
device="cuda"
)
print("βœ… Model initialized and ready.")
# ---------------------------------------------------------
# INFERENCE FUNCTION
# ---------------------------------------------------------
@spaces.GPU
def infer(prompt):
print("start infer")
global pipeline, config
print(pipeline)
if pipeline is None:
init_model()
print(pipeline)
prompts = [prompt]
# get prompt embedding
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers,
max_sequence_length=128,
device="cuda"
)
neg_embed, neg_pooled_embed = compute_text_embeddings(
[""], text_encoders, tokenizers,
max_sequence_length=128,
device="cuda"
)
neg_prompt_embeds = neg_embed.repeat(1, 1, 1)
neg_pooled_prompt_embeds = neg_pooled_embed.repeat(1, 1)
# generation seed
generator = torch.Generator().manual_seed(0)
with torch.no_grad():
images, _, _, _ = pipeline_with_logprob(
pipeline,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_prompt_embeds=neg_prompt_embeds,
negative_pooled_prompt_embeds=neg_pooled_prompt_embeds,
num_inference_steps=config.sample.eval_num_steps,
guidance_scale=config.sample.guidance_scale,
output_type="pt",
height=config.resolution,
width=config.resolution,
noise_level=0,
mini_num_image_per_prompt=1,
process_index=0,
sample_num_steps=config.sample.num_steps,
random_timestep=0,
generator=generator,
)
# Convert to PIL
pil = Image.fromarray(
(images[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
)
# Fixed 512x512 for output
pil = pil.resize((512, 512))
return pil
# ---------------------------------------------------------
# GRADIO UI
# ---------------------------------------------------------
# init_model()
demo = gr.Interface(
fn=infer,
inputs=gr.Textbox(lines=2, label="Prompt"),
outputs=gr.Image(type="pil"),
title="Adv-GRPO(DINO)",
description="Enter a prompt and generate image using Adv-GRPO",
)
demo.launch()