# (Hugging Face Spaces page residue — "Spaces: Sleeping" status banner;
#  not part of the module source.)
| import os | |
| import math | |
| import threading | |
| import spaces | |
| from dataclasses import dataclass | |
| import torch | |
| from transformers import ( | |
| CLIPTextModelWithProjection, | |
| CLIPTokenizer, | |
| ) | |
| from diffusers.models.autoencoders.vq_model import VQModel | |
| from src.smc.transformer import Transformer2DModel | |
| from src.smc.pipeline import Pipeline | |
| from src.meissonic.scheduler import Scheduler | |
| from src.smc.scheduler import ReMDMScheduler, MeissonicScheduler | |
| import src.smc.rewards as rewards | |
| from src.smc.resampling import resample | |
| from PIL import Image | |
| from typing import List | |
# Floor (in seconds) for any GPU-time request computed by the _get_*_duration
# helpers below — short jobs still reserve at least this much.
MIN_GPU_DURATION = 60
# Serialize the expensive / device-global operations across worker threads:
# pipeline construction, reward-model construction, and moving models between
# devices (each guarded by its own lock below).
pipe_build_lock = threading.Lock()
reward_model_build_lock = threading.Lock()
device_load_lock = threading.Lock()
def build_pipe(device):
    """Assemble the Monetico text-to-image pipeline from its pretrained parts.

    Loads the transformer, VQ-VAE, CLIP text encoder and tokenizer from the
    Hub, reads the pretrained scheduler's config, and wraps its masking
    schedule in a MeissonicScheduler bound to `device`.
    """
    repo_id = "Collov-Labs/Monetico"
    weight_dtype = torch.bfloat16

    transformer = Transformer2DModel.from_pretrained(
        repo_id, subfolder="transformer", torch_dtype=weight_dtype
    )
    vqvae = VQModel.from_pretrained(
        repo_id, subfolder="vqvae", torch_dtype=weight_dtype
    )
    # CLIP text encoder with projection — better for Monetico.
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        repo_id, subfolder="text_encoder", torch_dtype=weight_dtype
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        repo_id, subfolder="tokenizer", torch_dtype=weight_dtype
    )

    # The pretrained scheduler is loaded only so its config values can seed
    # the SMC-aware MeissonicScheduler actually used by the pipeline.
    pretrained_sched = Scheduler.from_pretrained(
        repo_id, subfolder="scheduler", torch_dtype=weight_dtype
    )
    smc_scheduler = MeissonicScheduler(
        mask_token_id=pretrained_sched.config.mask_token_id,  # type: ignore
        masking_schedule=pretrained_sched.config.masking_schedule,  # type: ignore
        device=device,
    )

    return Pipeline(
        vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=smc_scheduler,
    )
def load_lora_weights(pipe, lora_ckpt_uuid):
    """Apply a LoRA checkpoint to `pipe`.

    The checkpoint is looked up under the local 'checkpoints' directory by
    its UUID and loaded via the pipeline's own `load_lora_weights` method.
    """
    checkpoint_dir = os.path.join('checkpoints', lora_ckpt_uuid)
    pipe.load_lora_weights(pretrained_model_name_or_path_or_dict=checkpoint_dir)
@dataclass
class InferenceOutput:
    """Result bundle returned by the infer_*_with_pipe entry points.

    Note: the @dataclass decorator is required — instances are constructed
    with keyword arguments (InferenceOutput(images=..., ...)), which needs
    the generated __init__.
    """

    # Decoded output images, one per generated sample.
    images: List[Image.Image]
    # Per-image reward scores with the internal reward bias subtracted out.
    image_rewards: List[float]
    # Peak GPU memory allocated during the run, in GiB.
    gpu_mem_used: float
@dataclass
class PretrainedInferenceConfig:
    """Hyperparameters for plain (no-SMC) sampling from the pretrained model.

    @dataclass is required so the config can be constructed with keyword
    arguments and so `prompt` (annotation-only, no default) becomes a
    required init field.
    """

    # Text prompt to condition generation on (required).
    prompt: str
    # Quality-suppression negative prompt passed to the pipeline.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # Square output resolution in pixels.
    resolution: int = 512
    # Classifier-free guidance scale.
    CFG: float = 9.0
    # Number of inference (unmasking) steps.
    steps: int = 48
    # Number of batches passed to the pipeline.
    num_batches: int = 4
def infer_pretrained(config: PretrainedInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run no-SMC sampling from the pretrained model."""
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_pretrained_with_pipe(config, fresh_pipe, device=device)
def _get_pretrained_duration(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU allocation (seconds) for a pretrained inference run.

    Fixed setup cost plus a per-step cost, floored at MIN_GPU_DURATION.
    `pipe` and `device` are accepted for signature parity with the other
    _get_*_duration helpers and are not used here.
    """
    SETUP_SECONDS = 30.0
    SECONDS_PER_STEP = 1.0
    estimate = math.ceil(SETUP_SECONDS + SECONDS_PER_STEP * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
def infer_pretrained_with_pipe(config: PretrainedInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run no-SMC ("without_SMC") sampling with an already-built pipeline.

    Args:
        config: Sampling hyperparameters.
        pipe: Pipeline from build_pipe(); moved onto `device` here.
        device: torch.device or device string.

    Returns:
        InferenceOutput with PIL images, per-image rewards (bias removed),
        and peak GPU memory in GiB (0.0 when running on a non-CUDA device).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model shifts scores by a constant bias internally; it is
    # subtracted back out below so reported rewards are on the native scale.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Score every image against the single configured prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,  # plain sampling: a single particle, no SMC
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
    # torch.cuda.max_memory_allocated raises on non-CUDA devices, and the
    # default device is 'cpu' — guard and report 0.0 there.
    gpu_mem_used = (
        torch.cuda.max_memory_allocated(device) / 1024 ** 3 if device.type == "cuda" else 0.0
    )
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class SMCGradInferenceConfig:
    """Hyperparameters for gradient-guided SMC sampling.

    @dataclass is required so the config can be constructed with keyword
    arguments and so `prompt` (annotation-only, no default) becomes a
    required init field.
    """

    # Text prompt to condition generation on (required).
    prompt: str
    # Quality-suppression negative prompt passed to the pipeline.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # ESS fraction below which particles are resampled (1.0 = always).
    ess_threshold: float = 1.0
    # Whether to use partial (as opposed to full) resampling.
    partial_resampling: bool = False
    # Resample every this-many steps.
    resample_frequency: int = 4
    # Square output resolution in pixels.
    resolution: int = 512
    # Classifier-free guidance scale.
    CFG: float = 9.0
    # Number of inference (unmasking) steps.
    steps: int = 48
    # Weight of the KL term in the twisted target.
    kl_weight: float = 0.02
    # Whether to temper the reward weight (lambda) over the trajectory.
    lambda_tempering: bool = True
    # Fraction of steps over which lambda ramps from 0 to 1.
    lambda_one_at: float = 0.3
    # Number of batches passed to the pipeline.
    num_batches: int = 4
    # Number of SMC particles per batch.
    num_particles: int = 8
    # Proposal distribution type used by the pipeline.
    proposal_type: str = "locally_optimal"
    # Whether to use the continuous-time formulation.
    use_continuous_formulation: bool = True
    # Reward-gradient multiplicity (affects memory; see batch-size helper).
    phi: int = 1
    # Temperature parameter for the proposal.
    tau: float = 1.0
| def _get_batch_size_based_on_gpu_mem_smc_grad(device, phi): | |
| if device.type == "cuda": | |
| gpu_id = device.index if device.index is not None else 0 | |
| total_mem_gb = torch.cuda.get_device_properties(gpu_id).total_memory / 1024 ** 3 | |
| if phi == 1: | |
| if total_mem_gb < 24: | |
| batch_p = 1 | |
| elif total_mem_gb < 48: | |
| batch_p = 4 | |
| elif total_mem_gb < 70: | |
| batch_p = 7 | |
| else: | |
| batch_p = 8 | |
| elif phi <= 4: | |
| if total_mem_gb < 48: | |
| batch_p = 1 | |
| elif total_mem_gb < 70: | |
| batch_p = 2 | |
| else: | |
| batch_p = 4 | |
| else: | |
| batch_p = 1 | |
| else: | |
| batch_p = 1 | |
| return batch_p | |
def infer_smc_grad(config: SMCGradInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run gradient-guided SMC inference on it."""
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_smc_grad_with_pipe(config, fresh_pipe, device=device)
def _get_smc_grad_duration(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU allocation (seconds) for an SMC-grad inference run.

    Fixed setup cost plus a (larger) per-step cost, floored at
    MIN_GPU_DURATION. `pipe` and `device` are accepted for signature parity
    with the other _get_*_duration helpers and are not used here.
    """
    SETUP_SECONDS = 30.0
    SECONDS_PER_STEP = 6.0
    estimate = math.ceil(SETUP_SECONDS + SECONDS_PER_STEP * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run SMC sampling with gradient-based proposals using a built pipeline.

    Uses ImageReward as the reward function, optional lambda tempering, and
    periodic (optionally partial) resampling.

    Args:
        config: SMC sampling hyperparameters.
        pipe: Pipeline from build_pipe(); moved onto `device` here.
        device: torch.device or device string.

    Returns:
        InferenceOutput with PIL images, per-image rewards (bias removed),
        and peak GPU memory in GiB (0.0 when running on a non-CUDA device).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)

    # The reward model shifts scores by a constant bias internally; it is
    # subtracted back out below so reported rewards are on the native scale.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Score every image against the single configured prompt.
        return reward_fn(images, [config.prompt] * len(images))

    if config.lambda_tempering:
        # Ramp lambda linearly from 0 to 1 over the first lambda_one_at
        # fraction of steps, then hold it at 1 for the remaining steps.
        lambda_one_at = int(config.lambda_one_at * config.steps)
        lambdas = torch.cat([
            torch.linspace(0, 1, lambda_one_at + 1),
            torch.ones(config.steps - lambda_one_at),
        ])
    else:
        lambdas = None

    # Particle mini-batch size chosen to fit the available GPU memory.
    batch_p = _get_batch_size_based_on_gpu_mem_smc_grad(device, config.phi)

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(
            log_w,
            ess_threshold=config.ess_threshold,
            partial=config.partial_resampling,
        ),
        resample_frequency=config.resample_frequency,
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        kl_weight=config.kl_weight,
        lambdas=lambdas,
        batches=config.num_batches,
        num_particles=config.num_particles,
        batch_p=batch_p,
        # Previously hard-coded to "locally_optimal"; the config field has the
        # same default, so wiring it through preserves default behavior.
        proposal_type=config.proposal_type,
        use_continuous_formulation=config.use_continuous_formulation,
        phi=config.phi,
        tau=config.tau,
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
    # torch.cuda.max_memory_allocated raises on non-CUDA devices, and the
    # default device is 'cpu' — guard and report 0.0 there.
    gpu_mem_used = (
        torch.cuda.max_memory_allocated(device) / 1024 ** 3 if device.type == "cuda" else 0.0
    )
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
@dataclass
class FTInferenceConfig:
    """Hyperparameters for sampling from a LoRA fine-tuned model.

    @dataclass is required so the config can be constructed with keyword
    arguments and so `prompt` (annotation-only, no default) becomes a
    required init field.
    """

    # Text prompt to condition generation on (required).
    prompt: str
    # Quality-suppression negative prompt passed to the pipeline.
    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    # Square output resolution in pixels.
    resolution: int = 512
    # Classifier-free guidance scale.
    CFG: float = 9.0
    # Number of inference (unmasking) steps.
    steps: int = 48
    # Number of batches passed to the pipeline.
    num_batches: int = 4
    # UUID of the LoRA checkpoint under the local 'checkpoints' directory.
    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"
def infer_ft(config: FTInferenceConfig, device='cpu'):
    """Build a fresh pipeline and run inference with a LoRA fine-tuned model."""
    with pipe_build_lock:
        fresh_pipe = build_pipe(device)
    return infer_ft_with_pipe(config, fresh_pipe, device=device)
def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
    """Estimate the GPU allocation (seconds) for a fine-tuned inference run.

    Fixed setup cost plus a per-step cost, floored at MIN_GPU_DURATION.
    `pipe` and `device` are accepted for signature parity with the other
    _get_*_duration helpers and are not used here.
    """
    SETUP_SECONDS = 30.0
    SECONDS_PER_STEP = 1.0
    estimate = math.ceil(SETUP_SECONDS + SECONDS_PER_STEP * config.steps)
    return estimate if estimate > MIN_GPU_DURATION else MIN_GPU_DURATION
def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
    """Run no-SMC sampling with a LoRA fine-tuned pipeline.

    Identical to infer_pretrained_with_pipe except that the LoRA adapter
    identified by config.ckpt_uuid is applied to the pipeline first.

    Args:
        config: Sampling hyperparameters including the LoRA checkpoint UUID.
        pipe: Pipeline from build_pipe(); moved onto `device` here.
        device: torch.device or device string.

    Returns:
        InferenceOutput with PIL images, per-image rewards (bias removed),
        and peak GPU memory in GiB (0.0 when running on a non-CUDA device).
    """
    if isinstance(device, str):
        device = torch.device(device)
    with device_load_lock:
        pipe = pipe.to(device)
    # Apply the fine-tuned LoRA weights on top of the pretrained pipeline.
    load_lora_weights(pipe, config.ckpt_uuid)

    # The reward model shifts scores by a constant bias internally; it is
    # subtracted back out below so reported rewards are on the native scale.
    reward_bias = 5.0
    with reward_model_build_lock:
        reward_fn = rewards.ImageReward_Fk_Steering(
            device=device, device_load_lock=device_load_lock, bias=reward_bias
        )

    def image_reward_fn(images):
        # Score every image against the single configured prompt.
        return reward_fn(images, [config.prompt] * len(images))

    images = pipe(
        prompt=config.prompt,
        reward_fn=image_reward_fn,
        resample_fn=lambda log_w: resample(log_w),
        negative_prompt=config.negative_prompt,
        height=config.resolution,
        width=config.resolution,
        guidance_scale=config.CFG,
        num_inference_steps=config.steps,
        batches=config.num_batches,
        num_particles=1,  # plain sampling: a single particle, no SMC
        batch_p=config.num_batches,
        proposal_type="without_SMC",
        output_type="pt",
    )

    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
    # torch.cuda.max_memory_allocated raises on non-CUDA devices, and the
    # default device is 'cpu' — guard and report 0.0 there.
    gpu_mem_used = (
        torch.cuda.max_memory_allocated(device) / 1024 ** 3 if device.type == "cuda" else 0.0
    )
    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)