# smc_meissonic/src/smc/inference.py
# Committed by: cp524
# Commit message: Increase step duration in SMC grad inference calculation from 5.0 to 6.0
# Commit: 8bfeb05
import os
import math
import threading
import spaces
from dataclasses import dataclass
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers.models.autoencoders.vq_model import VQModel
from src.smc.transformer import Transformer2DModel
from src.smc.pipeline import Pipeline
from src.meissonic.scheduler import Scheduler
from src.smc.scheduler import ReMDMScheduler, MeissonicScheduler
import src.smc.rewards as rewards
from src.smc.resampling import resample
from PIL import Image
from typing import List
# Floor, in seconds, applied by the ``_get_*_duration`` estimators to the GPU
# time they request through ``spaces.GPU(duration=...)``.
MIN_GPU_DURATION: int = 60

# Module-level locks shared by the inference entry points:
#   * pipe_build_lock         - serialises ``build_pipe`` calls;
#   * reward_model_build_lock - serialises reward-model construction;
#   * device_load_lock        - serialises moving modules onto a device
#                               (it is also handed to the reward model).
pipe_build_lock = threading.Lock()
reward_model_build_lock = threading.Lock()
device_load_lock = threading.Lock()
def build_pipe(device):
    """Assemble the SMC ``Pipeline`` from the ``Collov-Labs/Monetico`` checkpoint.

    Every component is loaded from the Hub repository in bfloat16. Model
    weights are NOT moved to ``device`` here. ``device`` is only forwarded to
    the ``MeissonicScheduler``; the ``infer_*_with_pipe`` functions in this
    module move the pipeline with ``pipe.to(device)`` later.

    Args:
        device: Device passed to ``MeissonicScheduler``.

    Returns:
        A ``Pipeline`` wired with the VQ model, tokenizer, text encoder,
        transformer and a ``MeissonicScheduler``.
    """
    repo_id = "Collov-Labs/Monetico"
    weight_dtype = torch.bfloat16

    def _load(component_cls, subfolder):
        # Same call shape for every component: repo, subfolder, dtype.
        return component_cls.from_pretrained(repo_id, subfolder=subfolder, torch_dtype=weight_dtype)

    transformer = _load(Transformer2DModel, "transformer")
    vq_model = _load(VQModel, "vqvae")
    # The original author notes the projection variant is "better for Monetico".
    text_encoder = _load(CLIPTextModelWithProjection, "text_encoder")
    tokenizer = _load(CLIPTokenizer, "tokenizer")
    # The stock scheduler is only read for its config values below.
    stock_scheduler = _load(Scheduler, "scheduler")

    smc_scheduler = MeissonicScheduler(
        mask_token_id=stock_scheduler.config.mask_token_id,  # type: ignore
        masking_schedule=stock_scheduler.config.masking_schedule,  # type: ignore
        device=device,
    )
    return Pipeline(
        vq_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=smc_scheduler,
    )
def load_lora_weights(pipe, lora_ckpt_uuid):
    """Load a LoRA checkpoint stored under the local ``checkpoints`` directory.

    Args:
        pipe: Object exposing ``load_lora_weights`` (the inference ``Pipeline``).
        lora_ckpt_uuid: Name of the checkpoint sub-directory inside
            ``checkpoints`` (resolved relative to the current working directory).

    NOTE(review): ``os.path.join`` drops the ``checkpoints`` prefix when
    ``lora_ckpt_uuid`` is absolute. Confirm this value never comes from
    untrusted input.
    """
    lora_dir = os.path.join("checkpoints", lora_ckpt_uuid)
    pipe.load_lora_weights(pretrained_model_name_or_path_or_dict=lora_dir)
@dataclass
class InferenceOutput:
    """Result bundle returned by the ``infer_*_with_pipe`` functions."""

    # Generated images, post-processed to PIL.
    images: List[Image.Image]
    # Reward score for each image with the reward bias subtracted.
    image_rewards: List[float]
    # Peak memory reported by ``torch.cuda.max_memory_allocated``, in GiB.
    gpu_mem_used: float
@dataclass
class PretrainedInferenceConfig:
    """Settings for plain (non-SMC) sampling from the base model."""

    # Text prompt; also used to score the generated images.
    prompt: str
    # Forwarded to the pipeline as ``negative_prompt``.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # Used for both output height and width.
    resolution: int = 512
    # Forwarded as ``guidance_scale``.
    CFG: float = 9.0
    # Forwarded as ``num_inference_steps``.
    steps: int = 48
    # Forwarded as ``batches`` and ``batch_p`` (with a single particle each).
    num_batches: int = 4
def infer_pretrained(config: PretrainedInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run plain (non-SMC) sampling with it.

    Pipeline construction is serialised through ``pipe_build_lock``; the
    actual sampling is delegated to ``infer_pretrained_with_pipe``.
    """
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    result = infer_pretrained_with_pipe(config, fresh_pipe, device=device)
    return result
def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU seconds to reserve for ``infer_pretrained_with_pipe``.

    Mirrors the decorated function's signature because it is passed as
    ``spaces.GPU(duration=...)``; only ``config.steps`` is used. The estimate
    is 30 s of setup plus 1 s per step, rounded up, and never below
    ``MIN_GPU_DURATION``.
    """
    setup_s, per_step_s = 30.0, 1.0
    estimate = math.ceil(setup_s + per_step_s * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
@spaces.GPU(duration=_get_pretrained_duration)
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
    """Sample from the base model without SMC and score the resulting images.

    Runs inside a ``spaces.GPU`` reservation sized by
    ``_get_pretrained_duration``.

    Args:
        config: Sampling settings.
        pipe: Pipeline (e.g. from ``build_pipe``); it is moved to ``device``.
        device: Target device, as a string or ``torch.device``.

    Returns:
        InferenceOutput with the PIL images, their reward scores (reward bias
        removed) and the peak CUDA memory in GiB (0.0 on non-CUDA devices).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model is built with bias=reward_bias; the same bias is
    # subtracted from the scores reported to the caller.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Every image is scored against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        # One particle per batch with proposal_type="without_SMC": plain sampling.
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )
    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # torch.cuda.max_memory_allocated raises for a non-CUDA device. That would
    # discard the finished images when running with the default device='cpu',
    # so report 0.0 GiB there instead.
    if torch.device(device).type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class SMCGradInferenceConfig:
    """Settings for reward-guided SMC sampling."""

    # Text prompt; also used to score the generated images.
    prompt: str
    # Forwarded to the pipeline as ``negative_prompt``.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # Passed to ``resample`` as ``ess_threshold``.
    ess_threshold: float = 1.0
    # Passed to ``resample`` as ``partial``.
    partial_resampling: bool = False
    # Forwarded as ``resample_frequency``.
    resample_frequency: int = 4
    # Used for both output height and width.
    resolution: int = 512
    # Forwarded as ``guidance_scale``.
    CFG: float = 9.0
    # Forwarded as ``num_inference_steps``; also sizes the lambda schedule.
    steps: int = 48
    # Forwarded as ``kl_weight``.
    kl_weight: float = 0.02
    # When True, a lambda schedule is built and forwarded as ``lambdas``.
    lambda_tempering: bool = True
    # Fraction of ``steps`` after which the lambda schedule stays at 1.
    lambda_one_at: float = 0.3
    # Forwarded as ``batches``.
    num_batches: int = 4
    # Forwarded as ``num_particles``.
    num_particles: int = 8
    # Proposal type identifier.
    proposal_type: str = "locally_optimal"
    # Forwarded as ``use_continuous_formulation``.
    use_continuous_formulation: bool = True
    # Forwarded as ``phi``; also drives the GPU-memory batch-size heuristic.
    phi: int = 1
    # Forwarded as ``tau``.
    tau: float = 1.0
def _get_batch_size_based_on_gpu_mem_smc_grad(device, phi):
if device.type == "cuda":
gpu_id = device.index if device.index is not None else 0
total_mem_gb = torch.cuda.get_device_properties(gpu_id).total_memory / 1024 ** 3
if phi == 1:
if total_mem_gb < 24:
batch_p = 1
elif total_mem_gb < 48:
batch_p = 4
elif total_mem_gb < 70:
batch_p = 7
else:
batch_p = 8
elif phi <= 4:
if total_mem_gb < 48:
batch_p = 1
elif total_mem_gb < 70:
batch_p = 2
else:
batch_p = 4
else:
batch_p = 1
else:
batch_p = 1
return batch_p
def infer_smc_grad(config: SMCGradInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run reward-guided SMC sampling with it.

    Pipeline construction is serialised through ``pipe_build_lock``; the
    sampling itself is delegated to ``infer_smc_grad_with_pipe``.
    """
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    result = infer_smc_grad_with_pipe(config, fresh_pipe, device=device)
    return result
def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU seconds to reserve for ``infer_smc_grad_with_pipe``.

    Mirrors the decorated function's signature because it is passed as
    ``spaces.GPU(duration=...)``; only ``config.steps`` is used. The estimate
    is 30 s of setup plus 6 s per step, rounded up, and never below
    ``MIN_GPU_DURATION``.
    """
    setup_s, per_step_s = 30.0, 6.0
    estimate = math.ceil(setup_s + per_step_s * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
@spaces.GPU(duration=_get_smc_grad_duration)
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run reward-guided SMC sampling and score the resulting images.

    Runs inside a ``spaces.GPU`` reservation sized by ``_get_smc_grad_duration``.
    The reward function (``rewards.ImageReward_Fk_Steering``) is passed to the
    pipeline as ``reward_fn`` and is also used to score the final images.

    Args:
        config: SMC sampling settings.
        pipe: Pipeline (e.g. from ``build_pipe``); it is moved to ``device``.
        device: Target device, as a string or ``torch.device``.

    Returns:
        InferenceOutput with the PIL images, their reward scores (reward bias
        removed) and the peak CUDA memory in GiB (0.0 on non-CUDA devices).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model is built with bias=reward_bias; the same bias is
    # subtracted from the scores reported to the caller.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Every image is scored against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    if config.lambda_tempering:
        # Schedule of length steps + 1. It rises linearly from 0 to 1 over the
        # first int(lambda_one_at * steps) steps, then stays at 1.
        # NOTE(review): assumes 0 <= config.lambda_one_at <= 1. Values outside
        # that range can produce negative tensor sizes below.
        lambda_one_at = int(config.lambda_one_at * config.steps)
        lambdas = torch.cat([
            torch.linspace(0, 1, lambda_one_at + 1),
            torch.ones(config.steps - lambda_one_at),
        ])
    else:
        lambdas = None

    # Particles per forward pass, chosen from the GPU's total memory.
    batch_p = _get_batch_size_based_on_gpu_mem_smc_grad(device, config.phi)
    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w, ess_threshold=config.ess_threshold, partial=config.partial_resampling),
        resample_frequency=config.resample_frequency,
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        kl_weight=config.kl_weight,
        lambdas=lambdas,
        batches=config.num_batches,
        num_particles=config.num_particles,
        batch_p=batch_p,
        # Use the configured proposal (default "locally_optimal") instead of a
        # hard-coded value, so config.proposal_type actually takes effect.
        proposal_type=config.proposal_type,
        use_continuous_formulation=config.use_continuous_formulation,
        phi=config.phi,
        tau=config.tau,
        output_type="pt",
    )
    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # torch.cuda.max_memory_allocated raises for a non-CUDA device. That would
    # discard the finished images when running with the default device='cpu',
    # so report 0.0 GiB there instead.
    if torch.device(device).type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class FTInferenceConfig:
    """Settings for plain sampling with a fine-tuned LoRA checkpoint applied."""

    # Text prompt; also used to score the generated images.
    prompt: str
    # Forwarded to the pipeline as ``negative_prompt``.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # Used for both output height and width.
    resolution: int = 512
    # Forwarded as ``guidance_scale``.
    CFG: float = 9.0
    # Forwarded as ``num_inference_steps``.
    steps: int = 48
    # Forwarded as ``batches`` and ``batch_p`` (with a single particle each).
    num_batches: int = 4
    # Sub-directory of ``checkpoints`` holding the LoRA weights to load.
    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"
def infer_ft(config: FTInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run LoRA fine-tuned sampling with it.

    Pipeline construction is serialised through ``pipe_build_lock``; the
    sampling itself is delegated to ``infer_ft_with_pipe``.
    """
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    result = infer_ft_with_pipe(config, fresh_pipe, device=device)
    return result
def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU seconds to reserve for ``infer_ft_with_pipe``.

    Mirrors the decorated function's signature because it is passed as
    ``spaces.GPU(duration=...)``; only ``config.steps`` is used. The estimate
    is 30 s of setup plus 1 s per step, rounded up, and never below
    ``MIN_GPU_DURATION``.
    """
    setup_s, per_step_s = 30.0, 1.0
    estimate = math.ceil(setup_s + per_step_s * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
@spaces.GPU(duration=_get_ft_duration)
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
    """Sample with a LoRA checkpoint applied (no SMC) and score the images.

    Runs inside a ``spaces.GPU`` reservation sized by ``_get_ft_duration``.

    Args:
        config: Sampling settings, including the LoRA ``ckpt_uuid``.
        pipe: Pipeline (e.g. from ``build_pipe``); it is moved to ``device``
            and the LoRA weights are loaded into it.
        device: Target device, as a string or ``torch.device``.

    Returns:
        InferenceOutput with the PIL images, their reward scores (reward bias
        removed) and the peak CUDA memory in GiB (0.0 on non-CUDA devices).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)
    # NOTE(review): the adapter is loaded into ``pipe`` on every call and never
    # unloaded here. If the same pipeline object is reused by other callers,
    # they may see the LoRA weights. Confirm whether spaces.GPU isolates this
    # state, or whether the Pipeline supports unloading the adapter afterwards.
    load_lora_weights(pipe, config.ckpt_uuid)

    # The reward model is built with bias=reward_bias; the same bias is
    # subtracted from the scores reported to the caller.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Every image is scored against the same prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        # One particle per batch with proposal_type="without_SMC": plain sampling.
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )
    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore

    # torch.cuda.max_memory_allocated raises for a non-CUDA device. That would
    # discard the finished images when running with the default device='cpu',
    # so report 0.0 GiB there instead.
    if torch.device(device).type == "cuda":
        gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
    else:
        gpu_mem_used = 0.0
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)