Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,9 +30,9 @@ transformer_trainable_parameters = None
|
|
| 30 |
|
| 31 |
def load_lora_from_subfolder():
|
| 32 |
repo_id = "benzweijia/Adv-GRPO"
|
| 33 |
-
subfolder = "
|
| 34 |
|
| 35 |
-
local_dir = "/tmp/
|
| 36 |
os.makedirs(local_dir, exist_ok=True)
|
| 37 |
|
| 38 |
for filename in ["adapter_config.json", "adapter_model.safetensors"]:
|
|
@@ -107,7 +107,7 @@ def init_model():
|
|
| 107 |
pipeline.text_encoder_2.to("cuda")
|
| 108 |
pipeline.text_encoder_3.to("cuda")
|
| 109 |
pipeline.transformer.to("cuda")
|
| 110 |
-
config.train.lora_path = "benzweijia/Adv-GRPO/
|
| 111 |
config.use_lora = True
|
| 112 |
lora_dir = load_lora_from_subfolder()
|
| 113 |
|
|
@@ -115,7 +115,7 @@ def init_model():
|
|
| 115 |
print("🔥 Loading LoRA from:", config.train.lora_path)
|
| 116 |
pipeline.transformer = PeftModel.from_pretrained(
|
| 117 |
pipeline.transformer,
|
| 118 |
-
os.path.join(lora_dir,"
|
| 119 |
)
|
| 120 |
pipeline.transformer.set_adapter("default")
|
| 121 |
|
|
@@ -146,7 +146,6 @@ def infer(prompt):
|
|
| 146 |
if pipeline is None:
|
| 147 |
init_model()
|
| 148 |
print(pipeline)
|
| 149 |
-
print("start infer 1111")
|
| 150 |
|
| 151 |
|
| 152 |
prompts = [prompt]
|
|
@@ -157,7 +156,6 @@ def infer(prompt):
|
|
| 157 |
max_sequence_length=128,
|
| 158 |
device="cuda"
|
| 159 |
)
|
| 160 |
-
print("start infer 2")
|
| 161 |
|
| 162 |
neg_embed, neg_pooled_embed = compute_text_embeddings(
|
| 163 |
[""], text_encoders, tokenizers,
|
|
@@ -167,7 +165,6 @@ def infer(prompt):
|
|
| 167 |
|
| 168 |
neg_prompt_embeds = neg_embed.repeat(1, 1, 1)
|
| 169 |
neg_pooled_prompt_embeds = neg_pooled_embed.repeat(1, 1)
|
| 170 |
-
print("start infer 3")
|
| 171 |
|
| 172 |
# generation seed
|
| 173 |
generator = torch.Generator().manual_seed(0)
|
|
@@ -192,9 +189,6 @@ def infer(prompt):
|
|
| 192 |
generator=generator,
|
| 193 |
)
|
| 194 |
|
| 195 |
-
print("images type:", type(images))
|
| 196 |
-
print("images len:", len(images))
|
| 197 |
-
print("first image shape:", images[0].shape)
|
| 198 |
|
| 199 |
# Convert to PIL
|
| 200 |
pil = Image.fromarray(
|
|
@@ -216,7 +210,7 @@ demo = gr.Interface(
|
|
| 216 |
fn=infer,
|
| 217 |
inputs=gr.Textbox(lines=2, label="Prompt"),
|
| 218 |
outputs=gr.Image(type="pil"),
|
| 219 |
-
title="Adv-GRPO(
|
| 220 |
description="Enter a prompt and generate image using Adv-GRPO",
|
| 221 |
)
|
| 222 |
|
|
|
|
| 30 |
|
| 31 |
def load_lora_from_subfolder():
|
| 32 |
repo_id = "benzweijia/Adv-GRPO"
|
| 33 |
+
subfolder = "DINO"
|
| 34 |
|
| 35 |
+
local_dir = "/tmp/DINO"
|
| 36 |
os.makedirs(local_dir, exist_ok=True)
|
| 37 |
|
| 38 |
for filename in ["adapter_config.json", "adapter_model.safetensors"]:
|
|
|
|
| 107 |
pipeline.text_encoder_2.to("cuda")
|
| 108 |
pipeline.text_encoder_3.to("cuda")
|
| 109 |
pipeline.transformer.to("cuda")
|
| 110 |
+
config.train.lora_path = "benzweijia/Adv-GRPO/DINO"
|
| 111 |
config.use_lora = True
|
| 112 |
lora_dir = load_lora_from_subfolder()
|
| 113 |
|
|
|
|
| 115 |
print("🔥 Loading LoRA from:", config.train.lora_path)
|
| 116 |
pipeline.transformer = PeftModel.from_pretrained(
|
| 117 |
pipeline.transformer,
|
| 118 |
+
os.path.join(lora_dir,"DINO")
|
| 119 |
)
|
| 120 |
pipeline.transformer.set_adapter("default")
|
| 121 |
|
|
|
|
| 146 |
if pipeline is None:
|
| 147 |
init_model()
|
| 148 |
print(pipeline)
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
prompts = [prompt]
|
|
|
|
| 156 |
max_sequence_length=128,
|
| 157 |
device="cuda"
|
| 158 |
)
|
|
|
|
| 159 |
|
| 160 |
neg_embed, neg_pooled_embed = compute_text_embeddings(
|
| 161 |
[""], text_encoders, tokenizers,
|
|
|
|
| 165 |
|
| 166 |
neg_prompt_embeds = neg_embed.repeat(1, 1, 1)
|
| 167 |
neg_pooled_prompt_embeds = neg_pooled_embed.repeat(1, 1)
|
|
|
|
| 168 |
|
| 169 |
# generation seed
|
| 170 |
generator = torch.Generator().manual_seed(0)
|
|
|
|
| 189 |
generator=generator,
|
| 190 |
)
|
| 191 |
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
# Convert to PIL
|
| 194 |
pil = Image.fromarray(
|
|
|
|
| 210 |
fn=infer,
|
| 211 |
inputs=gr.Textbox(lines=2, label="Prompt"),
|
| 212 |
outputs=gr.Image(type="pil"),
|
| 213 |
+
title="Adv-GRPO(DINO)",
|
| 214 |
description="Enter a prompt and generate image using Adv-GRPO",
|
| 215 |
)
|
| 216 |
|