Spaces:
Sleeping
Sleeping
File size: 10,552 Bytes
3e0672b b1beaa0 4991517 971c192 ed53392 b113524 b1beaa0 9712fd8 b49998b 9712fd8 4991517 b113524 4991517 b113524 3e0672b b113524 f36c67c b113524 9712fd8 b1beaa0 9712fd8 b1beaa0 9712fd8 b49998b 9712fd8 b1beaa0 b49998b b113524 b1beaa0 b113524 b1beaa0 b113524 f36c67c b113524 1af81da b113524 9712fd8 8bfeb05 9712fd8 b1beaa0 9712fd8 b113524 b49998b 9712fd8 b1beaa0 b49998b b113524 b1beaa0 b113524 b1beaa0 b113524 3e0672b b49998b 3e0672b b49998b 3e0672b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
import os
import math
import threading
import spaces
from dataclasses import dataclass
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers.models.autoencoders.vq_model import VQModel
from src.smc.transformer import Transformer2DModel
from src.smc.pipeline import Pipeline
from src.meissonic.scheduler import Scheduler
from src.smc.scheduler import ReMDMScheduler, MeissonicScheduler
import src.smc.rewards as rewards
from src.smc.resampling import resample
from PIL import Image
from typing import List
# Floor applied by the `_get_*_duration` helpers to the GPU duration they return.
MIN_GPU_DURATION: int = 60

# Held while `build_pipe` constructs a pipeline.
pipe_build_lock = threading.Lock()

# Held while the ImageReward reward model is constructed.
reward_model_build_lock = threading.Lock()

# Held while a pipeline is moved to its target device; the same lock object is
# also handed to the reward model constructor.
device_load_lock = threading.Lock()
def build_pipe(device):
    """Assemble the Monetico pipeline from its pretrained components.

    The transformer, VQ model, CLIP text encoder, tokenizer and stock scheduler
    are each loaded from their subfolder of ``Collov-Labs/Monetico`` with
    ``torch_dtype=torch.bfloat16``. The stock scheduler is only consulted for
    its ``mask_token_id`` and ``masking_schedule`` config values, which seed
    the ``MeissonicScheduler`` that actually drives the pipeline.

    Args:
        device: Device passed to the ``MeissonicScheduler`` constructor. The
            pipeline components are not moved to it here.

    Returns:
        The assembled ``Pipeline``.
    """
    repo_id = "Collov-Labs/Monetico"
    weights_dtype = torch.bfloat16

    def _load(component_cls, subfolder):
        # Every component lives in its own subfolder of the same repository.
        return component_cls.from_pretrained(repo_id, subfolder=subfolder, torch_dtype=weights_dtype)

    transformer = _load(Transformer2DModel, "transformer")
    vqvae = _load(VQModel, "vqvae")
    clip_text_encoder = _load(CLIPTextModelWithProjection, "text_encoder")  # better for Monetico
    clip_tokenizer = _load(CLIPTokenizer, "tokenizer")
    stock_scheduler = _load(Scheduler, "scheduler")

    smc_scheduler = MeissonicScheduler(
        mask_token_id=stock_scheduler.config.mask_token_id,  # type: ignore
        masking_schedule=stock_scheduler.config.masking_schedule,  # type: ignore
        device=device,
    )
    return Pipeline(
        vqvae,
        tokenizer=clip_tokenizer,
        text_encoder=clip_text_encoder,
        transformer=transformer,
        scheduler=smc_scheduler,
    )
def load_lora_weights(pipe, lora_ckpt_uuid):
    """Load a LoRA checkpoint from ``checkpoints/<lora_ckpt_uuid>`` into ``pipe``.

    Args:
        pipe: Object exposing a ``load_lora_weights`` method (the pipeline).
        lora_ckpt_uuid: Name of the checkpoint entry under ``checkpoints``.
    """
    lora_location = os.path.join('checkpoints', lora_ckpt_uuid)
    pipe.load_lora_weights(pretrained_model_name_or_path_or_dict=lora_location)
@dataclass
class InferenceOutput:
    """Result bundle returned by the ``infer_*_with_pipe`` functions.

    Attributes:
        images: Generated images after post-processing to PIL.
        image_rewards: One reward value per image, with the reward model's
            bias already subtracted.
        gpu_mem_used: Peak CUDA memory allocated on the device, in GiB
            (bytes divided by 1024**3).
    """

    images: List[Image.Image]
    image_rewards: List[float]
    gpu_mem_used: float
@dataclass
class PretrainedInferenceConfig:
    """Settings for sampling from the pretrained model without SMC steering.

    Attributes:
        prompt: Text prompt; also used to score the generated images.
        negative_prompt: Negative prompt forwarded to the pipeline.
        resolution: Used for both the height and the width of the output.
        CFG: Forwarded to the pipeline as ``guidance_scale``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        num_batches: Forwarded to the pipeline as both ``batches`` and
            ``batch_p``.
    """

    prompt: str
    negative_prompt: str = (
        "worst quality, low quality, low res, blurry, distortion, watermark, "
        "logo, signature, text, jpeg artifacts, signature, sketch, duplicate, "
        "ugly, identifying mark"
    )
    resolution: int = 512
    CFG: float = 9.0
    steps: int = 48
    num_batches: int = 4
def infer_pretrained(config: PretrainedInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run pretrained sampling with it.

    Args:
        config: Sampling settings.
        device: Target device, passed to both the pipeline builder and the
            inference call.

    Returns:
        Whatever ``infer_pretrained_with_pipe`` returns for the new pipeline.
    """
    # Pipeline construction is serialized through the module-level build lock.
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_pretrained_with_pipe(config, fresh_pipe, device=device)
def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU duration to request for pretrained sampling.

    The estimate is a fixed setup cost of 30 plus 1 per inference step, rounded
    up and never below ``MIN_GPU_DURATION``. The signature mirrors
    ``infer_pretrained_with_pipe``, for which this function is registered as
    the ``spaces.GPU`` duration callback; ``pipe`` and ``device`` are not used.
    """
    setup_cost, per_step_cost = 30.0, 1.0
    estimate = math.ceil(setup_cost + per_step_cost * config.steps)
    return max(estimate, MIN_GPU_DURATION)
@spaces.GPU(duration=_get_pretrained_duration)
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
    """Sample images from the pretrained model (no SMC) and score them.

    Args:
        config: Sampling settings.
        pipe: Pipeline to sample with; it is moved to ``device`` first.
        device: Target device, given as a ``torch.device`` or a device string.

    Returns:
        An ``InferenceOutput`` holding the PIL images, one reward per image
        (with the reward model's bias removed) and the peak CUDA memory in
        GiB. The memory figure is ``0.0`` when ``device`` is not a CUDA device.
    """
    if isinstance(device, str):
        device = torch.device(device)

    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model is built with a positive bias; the bias is subtracted
    # again below before the rewards are reported.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias)

    def image_reward_fn(images):
        # Score every image against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # CUDA memory statistics are only meaningful for CUDA devices; querying
    # them with a CPU device can raise, so report zero in that case.
    if device.type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class SMCGradInferenceConfig:
    """Settings for gradient-based SMC sampling.

    Attributes:
        prompt: Text prompt; also used to score the generated images.
        negative_prompt: Negative prompt forwarded to the pipeline.
        ess_threshold: Forwarded to ``resample`` as ``ess_threshold``.
        partial_resampling: Forwarded to ``resample`` as ``partial``.
        resample_frequency: Forwarded to the pipeline.
        resolution: Used for both the height and the width of the output.
        CFG: Forwarded to the pipeline as ``guidance_scale``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        kl_weight: Forwarded to the pipeline.
        lambda_tempering: When true, a tempering schedule is built and passed
            to the pipeline; otherwise ``None`` is passed.
        lambda_one_at: Fraction of ``steps`` after which the tempering
            schedule has ramped up to 1.
        num_batches: Forwarded to the pipeline as ``batches``.
        num_particles: Forwarded to the pipeline.
        proposal_type: Name of the proposal to use.
        use_continuous_formulation: Forwarded to the pipeline.
        phi: Forwarded to the pipeline; also used to pick the particle batch
            size for the available GPU memory.
        tau: Forwarded to the pipeline.
    """

    prompt: str
    negative_prompt: str = (
        "worst quality, low quality, low res, blurry, distortion, watermark, "
        "logo, signature, text, jpeg artifacts, signature, sketch, duplicate, "
        "ugly, identifying mark"
    )
    ess_threshold: float = 1.0
    partial_resampling: bool = False
    resample_frequency: int = 4
    resolution: int = 512
    CFG: float = 9.0
    steps: int = 48
    kl_weight: float = 0.02
    lambda_tempering: bool = True
    lambda_one_at: float = 0.3
    num_batches: int = 4
    num_particles: int = 8
    proposal_type: str = "locally_optimal"
    use_continuous_formulation: bool = True
    phi: int = 1
    tau: float = 1.0
def _get_batch_size_based_on_gpu_mem_smc_grad(device, phi):
if device.type == "cuda":
gpu_id = device.index if device.index is not None else 0
total_mem_gb = torch.cuda.get_device_properties(gpu_id).total_memory / 1024 ** 3
if phi == 1:
if total_mem_gb < 24:
batch_p = 1
elif total_mem_gb < 48:
batch_p = 4
elif total_mem_gb < 70:
batch_p = 7
else:
batch_p = 8
elif phi <= 4:
if total_mem_gb < 48:
batch_p = 1
elif total_mem_gb < 70:
batch_p = 2
else:
batch_p = 4
else:
batch_p = 1
else:
batch_p = 1
return batch_p
def infer_smc_grad(config: SMCGradInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run gradient-based SMC sampling with it.

    Args:
        config: SMC sampling settings.
        device: Target device, passed to both the pipeline builder and the
            inference call.

    Returns:
        Whatever ``infer_smc_grad_with_pipe`` returns for the new pipeline.
    """
    # Pipeline construction is serialized through the module-level build lock.
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_smc_grad_with_pipe(config, fresh_pipe, device=device)
def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU duration to request for gradient-based SMC sampling.

    The estimate is a fixed setup cost of 30 plus 6 per inference step, rounded
    up and never below ``MIN_GPU_DURATION``. The signature mirrors
    ``infer_smc_grad_with_pipe``, for which this function is registered as the
    ``spaces.GPU`` duration callback; ``pipe`` and ``device`` are not used.
    """
    setup_cost, per_step_cost = 30.0, 6.0
    estimate = math.ceil(setup_cost + per_step_cost * config.steps)
    return max(estimate, MIN_GPU_DURATION)
@spaces.GPU(duration=_get_smc_grad_duration)
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
    """Sample images with gradient-based SMC steering and score them.

    Args:
        config: SMC sampling settings. ``config.proposal_type`` is forwarded
            to the pipeline (its default is ``"locally_optimal"``).
        pipe: Pipeline to sample with; it is moved to ``device`` first.
        device: Target device, given as a ``torch.device`` or a device string.

    Returns:
        An ``InferenceOutput`` holding the PIL images, one reward per image
        (with the reward model's bias removed) and the peak CUDA memory in
        GiB. The memory figure is ``0.0`` when ``device`` is not a CUDA device.
    """
    if isinstance(device, str):
        device = torch.device(device)

    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model is built with a positive bias; the bias is subtracted
    # again below before the rewards are reported.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias)

    def image_reward_fn(images):
        # Score every image against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    def resample_fn(log_w):
        return resample(log_w, ess_threshold=config.ess_threshold, partial=config.partial_resampling)

    if config.lambda_tempering:
        # Linear ramp from 0 to 1 over the first `ramp_steps` steps, then held
        # at 1; the schedule has `config.steps + 1` entries in total.
        ramp_steps = int(config.lambda_one_at * config.steps)
        lambdas = torch.cat([torch.linspace(0, 1, ramp_steps + 1), torch.ones(config.steps - ramp_steps)])
    else:
        lambdas = None

    batch_p = _get_batch_size_based_on_gpu_mem_smc_grad(device, config.phi)

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=resample_fn,
        resample_frequency=config.resample_frequency,
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        kl_weight=config.kl_weight,
        lambdas=lambdas,
        batches=config.num_batches,
        num_particles=config.num_particles,
        batch_p=batch_p,
        # Honor the configured proposal instead of a hard-coded value.
        proposal_type=config.proposal_type,
        use_continuous_formulation=config.use_continuous_formulation,
        phi=config.phi,
        tau=config.tau,
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # CUDA memory statistics are only meaningful for CUDA devices; querying
    # them with a CPU device can raise, so report zero in that case.
    if device.type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class FTInferenceConfig:
    """Settings for sampling from the LoRA fine-tuned model.

    Attributes:
        prompt: Text prompt; also used to score the generated images.
        negative_prompt: Negative prompt forwarded to the pipeline.
        resolution: Used for both the height and the width of the output.
        CFG: Forwarded to the pipeline as ``guidance_scale``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        num_batches: Forwarded to the pipeline as both ``batches`` and
            ``batch_p``.
        ckpt_uuid: Name of the LoRA checkpoint entry under ``checkpoints``.
    """

    prompt: str
    negative_prompt: str = (
        "worst quality, low quality, low res, blurry, distortion, watermark, "
        "logo, signature, text, jpeg artifacts, signature, sketch, duplicate, "
        "ugly, identifying mark"
    )
    resolution: int = 512
    CFG: float = 9.0
    steps: int = 48
    num_batches: int = 4
    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"
def infer_ft(config: FTInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run fine-tuned (LoRA) sampling with it.

    Args:
        config: Sampling settings, including the LoRA checkpoint to load.
        device: Target device, passed to both the pipeline builder and the
            inference call.

    Returns:
        Whatever ``infer_ft_with_pipe`` returns for the new pipeline.
    """
    # Pipeline construction is serialized through the module-level build lock.
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_ft_with_pipe(config, fresh_pipe, device=device)
def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU duration to request for fine-tuned sampling.

    The estimate is a fixed setup cost of 30 plus 1 per inference step, rounded
    up and never below ``MIN_GPU_DURATION``. The signature mirrors
    ``infer_ft_with_pipe``, for which this function is registered as the
    ``spaces.GPU`` duration callback; ``pipe`` and ``device`` are not used.
    """
    setup_cost, per_step_cost = 30.0, 1.0
    estimate = math.ceil(setup_cost + per_step_cost * config.steps)
    return max(estimate, MIN_GPU_DURATION)
@spaces.GPU(duration=_get_ft_duration)
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
    """Sample images from the LoRA fine-tuned model (no SMC) and score them.

    Args:
        config: Sampling settings; ``config.ckpt_uuid`` selects the LoRA
            checkpoint loaded into the pipeline.
        pipe: Pipeline to sample with; it is moved to ``device`` and then has
            the LoRA weights loaded into it.
        device: Target device, given as a ``torch.device`` or a device string.

    Returns:
        An ``InferenceOutput`` holding the PIL images, one reward per image
        (with the reward model's bias removed) and the peak CUDA memory in
        GiB. The memory figure is ``0.0`` when ``device`` is not a CUDA device.
    """
    if isinstance(device, str):
        device = torch.device(device)

    with device_load_lock:
        pipe = pipe.to(device)
    load_lora_weights(pipe, config.ckpt_uuid)

    # The reward model is built with a positive bias; the bias is subtracted
    # again below before the rewards are reported.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(device=device, device_load_lock=device_load_lock, bias=reward_bias)

    def image_reward_fn(images):
        # Score every image against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # CUDA memory statistics are only meaningful for CUDA devices; querying
    # them with a CPU device can raise, so report zero in that case.
    if device.type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
|