"""Inference entry points for the Monetico (Meissonic-style) pipeline.

Three inference modes are exposed. Each has a plain wrapper that builds a
fresh pipeline and a ``*_with_pipe`` worker that runs on GPU via the
``@spaces.GPU`` decorator:

* ``infer_pretrained`` - plain sampling from the pretrained model (no SMC).
* ``infer_smc_grad``   - SMC sampling with gradient-based (locally optimal)
  proposals, reward tempering, and resampling.
* ``infer_ft``         - plain sampling from a LoRA fine-tuned checkpoint.
"""

import math
import os
import threading
from dataclasses import dataclass
from typing import Callable, List

import spaces
import torch
from PIL import Image
from diffusers.models.autoencoders.vq_model import VQModel
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

import src.smc.rewards as rewards
from src.meissonic.scheduler import Scheduler
from src.smc.pipeline import Pipeline
from src.smc.resampling import resample
from src.smc.scheduler import ReMDMScheduler, MeissonicScheduler  # noqa: F401
from src.smc.transformer import Transformer2DModel

# Minimum GPU allocation (seconds) requested from the Spaces scheduler.
MIN_GPU_DURATION = 60

# Constant added to raw ImageReward scores inside the steering reward model
# and subtracted back out before rewards are reported to the caller.
_REWARD_BIAS = 5.0

# Shared default negative prompt for every inference config.
_DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, low res, blurry, distortion, watermark, "
    "logo, signature, text, jpeg artifacts, signature, sketch, duplicate, "
    "ugly, identifying mark"
)

# Serialize the heavyweight, stateful phases (model construction, reward-model
# construction, device placement) so concurrent requests do not race.
pipe_build_lock = threading.Lock()
reward_model_build_lock = threading.Lock()
device_load_lock = threading.Lock()


def build_pipe(device):
    """Build the Monetico pipeline: transformer, VQ-VAE, CLIP text encoder,
    tokenizer, and an SMC-aware scheduler seeded from the pretrained one.

    Parameters
    ----------
    device : str or torch.device
        Device handed to the SMC scheduler. The pipeline modules themselves
        are moved to the GPU later, inside the ``@spaces.GPU`` workers.
    """
    model_path = "Collov-Labs/Monetico"
    dtype = torch.bfloat16
    model = Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
    vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", torch_dtype=dtype)
    # CLIP text encoder: better for Monetico (original author's note).
    text_encoder = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
    scheduler = Scheduler.from_pretrained(model_path, subfolder="scheduler", torch_dtype=dtype)
    # Re-wrap only the masking config of the pretrained scheduler in the
    # SMC-capable MeissonicScheduler.
    scheduler_new = MeissonicScheduler(
        mask_token_id=scheduler.config.mask_token_id,  # type: ignore
        masking_schedule=scheduler.config.masking_schedule,  # type: ignore
        device=device,
    )
    pipe = Pipeline(
        vq_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=model,
        scheduler=scheduler_new,
    )
    return pipe


def load_lora_weights(pipe, lora_ckpt_uuid):
    """Load a LoRA checkpoint stored under ``checkpoints/<uuid>`` into *pipe*."""
    ckpt_path = os.path.join('checkpoints', lora_ckpt_uuid)
    pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict=ckpt_path,
    )


@dataclass
class InferenceOutput:
    """Result bundle returned by every inference entry point."""
    images: List[Image.Image]   # decoded samples as PIL images
    image_rewards: List[float]  # per-image ImageReward score (bias removed)
    gpu_mem_used: float         # peak CUDA memory allocated, in GiB


def _coerce_device(pipe, device):
    """Normalize *device* to ``torch.device`` and move *pipe* onto it.

    Returns the (possibly new) pipeline object and the normalized device.
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)
    return pipe, device


def _build_image_reward_fn(prompt: str, device) -> Callable:
    """Return a callable scoring a batch of images against *prompt* using the
    biased ImageReward model employed for FK steering."""
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=_REWARD_BIAS
        )

    def image_reward_fn(images):
        # Score every image against the same prompt.
        return reward_fn(images, [prompt] * len(images))

    return image_reward_fn


def _finalize_outputs(pipe, images, image_reward_fn, device) -> InferenceOutput:
    """Score sampled images, convert them to PIL, and report peak GPU memory."""
    # Subtract the bias so reported rewards are raw ImageReward scores.
    image_rewards = (image_reward_fn(images) - _REWARD_BIAS).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
    gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024 ** 3
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)


@dataclass
class PretrainedInferenceConfig:
    """Settings for plain (no-SMC) sampling from the pretrained model."""
    prompt: str
    negative_prompt: str = _DEFAULT_NEGATIVE_PROMPT
    resolution: int = 512
    CFG: float = 9.0      # classifier-free guidance scale
    steps: int = 48
    num_batches: int = 4


def infer_pretrained(config: PretrainedInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run plain pretrained sampling."""
    with pipe_build_lock:
        pipe = build_pipe(device)
    return infer_pretrained_with_pipe(config, pipe, device=device)


def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate GPU seconds for @spaces.GPU: fixed setup plus ~1 s per step."""
    setup_duration = 30.0
    step_duration = 1.0
    total_duration = math.ceil(setup_duration + step_duration * config.steps)
    return max(total_duration, MIN_GPU_DURATION)


@spaces.GPU(duration=_get_pretrained_duration)
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run plain pretrained sampling (single particle, no SMC) on the GPU."""
    pipe, device = _coerce_device(pipe, device)
    image_reward_fn = _build_image_reward_fn(config.prompt, device)
    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )
    return _finalize_outputs(pipe, images, image_reward_fn, device)


@dataclass
class SMCGradInferenceConfig:
    """Settings for SMC sampling with gradient-based (locally optimal) proposals."""
    prompt: str
    negative_prompt: str = _DEFAULT_NEGATIVE_PROMPT
    ess_threshold: float = 1.0        # resample when ESS fraction falls below this
    partial_resampling: bool = False
    resample_frequency: int = 4
    resolution: int = 512
    CFG: float = 9.0                  # classifier-free guidance scale
    steps: int = 48
    kl_weight: float = 0.02
    lambda_tempering: bool = True
    lambda_one_at: float = 0.3        # fraction of steps at which lambda reaches 1
    num_batches: int = 4
    num_particles: int = 8
    proposal_type: str = "locally_optimal"
    use_continuous_formulation: bool = True
    phi: int = 1
    tau: float = 1.0


def _get_batch_size_based_on_gpu_mem_smc_grad(device, phi):
    """Pick how many particles fit on the device at once for gradient-SMC.

    The thresholds form an empirical table keyed on total GPU memory (GiB)
    and the multiplier ``phi``; non-CUDA devices always get 1.
    """
    if device.type != "cuda":
        return 1
    gpu_id = device.index if device.index is not None else 0
    total_mem_gb = torch.cuda.get_device_properties(gpu_id).total_memory / 1024 ** 3
    if phi == 1:
        if total_mem_gb < 24:
            return 1
        if total_mem_gb < 48:
            return 4
        if total_mem_gb < 70:
            return 7
        return 8
    if phi <= 4:
        if total_mem_gb < 48:
            return 1
        if total_mem_gb < 70:
            return 2
        return 4
    return 1


def infer_smc_grad(config: SMCGradInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run gradient-based SMC sampling."""
    with pipe_build_lock:
        pipe = build_pipe(device)
    return infer_smc_grad_with_pipe(config, pipe, device=device)


def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate GPU seconds for @spaces.GPU: fixed setup plus ~6 s per step."""
    setup_duration = 30.0
    step_duration = 6.0
    total_duration = math.ceil(setup_duration + step_duration * config.steps)
    return max(total_duration, MIN_GPU_DURATION)


@spaces.GPU(duration=_get_smc_grad_duration)
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run SMC sampling with locally optimal proposals on the GPU."""
    pipe, device = _coerce_device(pipe, device)
    image_reward_fn = _build_image_reward_fn(config.prompt, device)
    if config.lambda_tempering:
        # Ramp lambda linearly from 0 to 1 over the first `lambda_one_at`
        # fraction of steps, then hold it at 1 for the remainder.
        lambda_one_at = int(config.lambda_one_at * config.steps)
        lambdas = torch.cat([
            torch.linspace(0, 1, lambda_one_at + 1),
            torch.ones(config.steps - lambda_one_at),
        ])
    else:
        lambdas = None
    batch_p = _get_batch_size_based_on_gpu_mem_smc_grad(device, config.phi)
    # NOTE(review): `config.proposal_type` is declared but the original code
    # hard-codes "locally_optimal" here — preserved as-is; confirm intent.
    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(
            log_w, ess_threshold=config.ess_threshold, partial=config.partial_resampling
        ),
        resample_frequency=config.resample_frequency,
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        kl_weight=config.kl_weight,
        lambdas=lambdas,
        batches=config.num_batches,
        num_particles=config.num_particles,
        batch_p=batch_p,
        proposal_type="locally_optimal",
        use_continuous_formulation=config.use_continuous_formulation,
        phi=config.phi,
        tau=config.tau,
        output_type="pt",
    )
    return _finalize_outputs(pipe, images, image_reward_fn, device)


@dataclass
class FTInferenceConfig:
    """Settings for plain sampling from a LoRA fine-tuned checkpoint."""
    prompt: str
    negative_prompt: str = _DEFAULT_NEGATIVE_PROMPT
    resolution: int = 512
    CFG: float = 9.0      # classifier-free guidance scale
    steps: int = 48
    num_batches: int = 4
    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"  # default LoRA checkpoint id


def infer_ft(config: FTInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run sampling with LoRA weights applied."""
    with pipe_build_lock:
        pipe = build_pipe(device)
    return infer_ft_with_pipe(config, pipe, device=device)


def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate GPU seconds for @spaces.GPU: fixed setup plus ~1 s per step."""
    setup_duration = 30.0
    step_duration = 1.0
    total_duration = math.ceil(setup_duration + step_duration * config.steps)
    return max(total_duration, MIN_GPU_DURATION)


@spaces.GPU(duration=_get_ft_duration)
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run plain sampling (single particle, no SMC) from the LoRA-tuned model."""
    pipe, device = _coerce_device(pipe, device)
    # LoRA weights are loaded after the pipeline is on the target device.
    load_lora_weights(pipe, config.ckpt_uuid)
    image_reward_fn = _build_image_reward_fn(config.prompt, device)
    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )
    return _finalize_outputs(pipe, images, image_reward_fn, device)