Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Copyright (c)
|
| 2 |
|
| 3 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
# of this software and associated documentation files (the "Software"), to deal
|
|
@@ -19,7 +19,7 @@
|
|
| 19 |
# SOFTWARE.
|
| 20 |
|
| 21 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
| 22 |
-
from diffusers import
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
|
@@ -31,15 +31,15 @@ from typing import Tuple, List, Literal, Optional, Union
|
|
| 31 |
from tqdm import tqdm
|
| 32 |
from PIL import Image
|
| 33 |
|
| 34 |
-
from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
| 35 |
|
| 36 |
|
| 37 |
-
class
|
| 38 |
def __init__(
|
| 39 |
self,
|
| 40 |
device: torch.device,
|
| 41 |
dtype: torch.dtype = torch.float16,
|
| 42 |
-
sd_version: Literal['1.5'
|
| 43 |
hf_key: Optional[str] = None,
|
| 44 |
lora_key: Optional[str] = None,
|
| 45 |
load_from_local: bool = False, # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
|
|
@@ -52,8 +52,9 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 52 |
default_preprocess_mask_cover_alpha: float = 0.3,
|
| 53 |
t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
|
| 54 |
mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
|
|
|
|
| 55 |
) -> None:
|
| 56 |
-
r"""Stabilized
|
| 57 |
|
| 58 |
Accelerated region-based text-to-image synthesis with Latent Consistency
|
| 59 |
Model while preserving mask fidelity and quality.
|
|
@@ -95,13 +96,16 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 95 |
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
| 96 |
where each mask covered by other masks is reduced in its alpha
|
| 97 |
value by this specified factor.
|
| 98 |
-
t_index_list (List[int]): The default scheduling for
|
| 99 |
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
| 100 |
defines the mask quantization modes. Details are in the code of
|
| 101 |
`self.process_mask`. Basically, this (subtly) controls the
|
| 102 |
smoothness of foreground-background blending. More continuous
|
| 103 |
means more blending, but smaller generated patch depending on
|
| 104 |
the mask standard deviation.
|
|
|
|
|
|
|
|
|
|
| 105 |
"""
|
| 106 |
super().__init__()
|
| 107 |
|
|
@@ -120,30 +124,24 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 120 |
self.mask_type = mask_type
|
| 121 |
|
| 122 |
print(f'[INFO] Loading Stable Diffusion...')
|
| 123 |
-
variant = None
|
| 124 |
lora_weight_name = None
|
| 125 |
if self.sd_version == '1.5':
|
| 126 |
if hf_key is not None:
|
| 127 |
-
print(f'[INFO] Using
|
| 128 |
model_key = hf_key
|
| 129 |
else:
|
| 130 |
model_key = 'runwayml/stable-diffusion-v1-5'
|
| 131 |
-
# variant = 'fp16'
|
| 132 |
lora_key = 'latent-consistency/lcm-lora-sdv1-5'
|
| 133 |
lora_weight_name = 'pytorch_lora_weights.safetensors'
|
| 134 |
-
# elif self.sd_version == 'xl':
|
| 135 |
-
# model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
|
| 136 |
-
# lora_key = 'latent-consistency/lcm-lora-sdxl'
|
| 137 |
-
# variant = 'fp16'
|
| 138 |
-
# lora_weight_name = 'pytorch_lora_weights.safetensors'
|
| 139 |
else:
|
| 140 |
raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
|
| 141 |
|
| 142 |
# Create model
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
|
| 146 |
-
self.pipe =
|
| 147 |
if lora_key is None:
|
| 148 |
print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
|
| 149 |
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
|
@@ -166,7 +164,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 166 |
self.vae_scale_factor = self.pipe.vae_scale_factor
|
| 167 |
|
| 168 |
# Prepare white background for bootstrapping.
|
| 169 |
-
|
| 170 |
|
| 171 |
print(f'[INFO] Model is loaded!')
|
| 172 |
|
|
@@ -281,11 +279,14 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 281 |
Returns:
|
| 282 |
A single string of text prompt.
|
| 283 |
"""
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
@torch.no_grad()
|
| 291 |
def encode_imgs(
|
|
@@ -405,7 +406,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 405 |
25, 37], the masks are split into binary masks whose values are
|
| 406 |
greater than these levels. This results in a gradual increase of mask
|
| 407 |
region as the timesteps increase. Details are described in our
|
| 408 |
-
paper
|
| 409 |
|
| 410 |
On the Three Modes of `mask_type`:
|
| 411 |
`self.mask_type` is predefined at the initialization stage of this
|
|
@@ -609,7 +610,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 609 |
|
| 610 |
Minimal Example:
|
| 611 |
>>> device = torch.device('cuda:0')
|
| 612 |
-
>>> smd =
|
| 613 |
>>> image = smd.sample('A photo of the dolomites')
|
| 614 |
>>> image.save('my_creation.png')
|
| 615 |
|
|
@@ -675,7 +676,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 675 |
|
| 676 |
Minimal Example:
|
| 677 |
>>> device = torch.device('cuda:0')
|
| 678 |
-
>>> smd =
|
| 679 |
>>> image = smd.sample_panorama(
|
| 680 |
>>> 'A photo of Alps', height=512, width=3072)
|
| 681 |
>>> image.save('my_panorama_creation.png')
|
|
@@ -792,7 +793,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 792 |
|
| 793 |
Example:
|
| 794 |
>>> device = torch.device('cuda:0')
|
| 795 |
-
>>> smd =
|
| 796 |
>>> prompts = {... specify prompts}
|
| 797 |
>>> masks = {... specify mask tensors}
|
| 798 |
>>> height, width = masks.shape[-2:]
|
|
@@ -881,7 +882,7 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 881 |
|
| 882 |
# prompts is None: return background.
|
| 883 |
# masks is None but prompts is not None: return prompts
|
| 884 |
-
# masks is not None and prompts is not None: Do
|
| 885 |
|
| 886 |
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
| 887 |
if background is None and background_prompt is not None:
|
|
@@ -1103,4 +1104,4 @@ class StableMultiDiffusionPipeline(nn.Module):
|
|
| 1103 |
image = blend(image, background[0], fg_mask)
|
| 1104 |
else:
|
| 1105 |
image = T.ToPILImage()(image)
|
| 1106 |
-
return image
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Jaerin Lee
|
| 2 |
|
| 3 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
# of this software and associated documentation files (the "Software"), to deal
|
|
|
|
| 19 |
# SOFTWARE.
|
| 20 |
|
| 21 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
| 22 |
+
from diffusers import LCMScheduler, DDIMScheduler, AutoencoderTiny
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
|
|
|
| 31 |
from tqdm import tqdm
|
| 32 |
from PIL import Image
|
| 33 |
|
| 34 |
+
from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
| 35 |
|
| 36 |
|
| 37 |
+
class SemanticDrawPipeline(nn.Module):
|
| 38 |
def __init__(
|
| 39 |
self,
|
| 40 |
device: torch.device,
|
| 41 |
dtype: torch.dtype = torch.float16,
|
| 42 |
+
sd_version: Literal['1.5'] = '1.5',
|
| 43 |
hf_key: Optional[str] = None,
|
| 44 |
lora_key: Optional[str] = None,
|
| 45 |
load_from_local: bool = False, # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
|
|
|
|
| 52 |
default_preprocess_mask_cover_alpha: float = 0.3,
|
| 53 |
t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
|
| 54 |
mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
|
| 55 |
+
has_i2t: bool = True,
|
| 56 |
) -> None:
|
| 57 |
+
r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
|
| 58 |
|
| 59 |
Accelerated region-based text-to-image synthesis with Latent Consistency
|
| 60 |
Model while preserving mask fidelity and quality.
|
|
|
|
| 96 |
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
| 97 |
where each mask covered by other masks is reduced in its alpha
|
| 98 |
value by this specified factor.
|
| 99 |
+
t_index_list (List[int]): The default scheduling for the scheduler.
|
| 100 |
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
| 101 |
defines the mask quantization modes. Details are in the code of
|
| 102 |
`self.process_mask`. Basically, this (subtly) controls the
|
| 103 |
smoothness of foreground-background blending. More continuous
|
| 104 |
means more blending, but smaller generated patch depending on
|
| 105 |
the mask standard deviation.
|
| 106 |
+
has_i2t (bool): Automatic background image to text prompt con-
|
| 107 |
+
version with BLIP-2 model. May not be necessary for the non-
|
| 108 |
+
streaming application.
|
| 109 |
"""
|
| 110 |
super().__init__()
|
| 111 |
|
|
|
|
| 124 |
self.mask_type = mask_type
|
| 125 |
|
| 126 |
print(f'[INFO] Loading Stable Diffusion...')
|
|
|
|
| 127 |
lora_weight_name = None
|
| 128 |
if self.sd_version == '1.5':
|
| 129 |
if hf_key is not None:
|
| 130 |
+
print(f'[INFO] Using custom model key: {hf_key}')
|
| 131 |
model_key = hf_key
|
| 132 |
else:
|
| 133 |
model_key = 'runwayml/stable-diffusion-v1-5'
|
|
|
|
| 134 |
lora_key = 'latent-consistency/lcm-lora-sdv1-5'
|
| 135 |
lora_weight_name = 'pytorch_lora_weights.safetensors'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
else:
|
| 137 |
raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
|
| 138 |
|
| 139 |
# Create model
|
| 140 |
+
if has_i2t:
|
| 141 |
+
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
| 142 |
+
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
| 143 |
|
| 144 |
+
self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
|
| 145 |
if lora_key is None:
|
| 146 |
print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
|
| 147 |
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
|
|
|
| 164 |
self.vae_scale_factor = self.pipe.vae_scale_factor
|
| 165 |
|
| 166 |
# Prepare white background for bootstrapping.
|
| 167 |
+
self.get_white_background(768, 768)
|
| 168 |
|
| 169 |
print(f'[INFO] Model is loaded!')
|
| 170 |
|
|
|
|
| 279 |
Returns:
|
| 280 |
A single string of text prompt.
|
| 281 |
"""
|
| 282 |
+
if hasattr(self, 'i2t_model'):
|
| 283 |
+
question = 'Question: What are in the image? Answer:'
|
| 284 |
+
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
| 285 |
+
out = self.i2t_model.generate(**inputs, max_new_tokens=77)
|
| 286 |
+
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
| 287 |
+
return prompt
|
| 288 |
+
else:
|
| 289 |
+
return ''
|
| 290 |
|
| 291 |
@torch.no_grad()
|
| 292 |
def encode_imgs(
|
|
|
|
| 406 |
25, 37], the masks are split into binary masks whose values are
|
| 407 |
greater than these levels. This results in a gradual increase of mask
|
| 408 |
region as the timesteps increase. Details are described in our
|
| 409 |
+
paper.
|
| 410 |
|
| 411 |
On the Three Modes of `mask_type`:
|
| 412 |
`self.mask_type` is predefined at the initialization stage of this
|
|
|
|
| 610 |
|
| 611 |
Minimal Example:
|
| 612 |
>>> device = torch.device('cuda:0')
|
| 613 |
+
>>> smd = SemanticDrawPipeline(device)
|
| 614 |
>>> image = smd.sample('A photo of the dolomites')
|
| 615 |
>>> image.save('my_creation.png')
|
| 616 |
|
|
|
|
| 676 |
|
| 677 |
Minimal Example:
|
| 678 |
>>> device = torch.device('cuda:0')
|
| 679 |
+
>>> smd = SemanticDrawPipeline(device)
|
| 680 |
>>> image = smd.sample_panorama(
|
| 681 |
>>> 'A photo of Alps', height=512, width=3072)
|
| 682 |
>>> image.save('my_panorama_creation.png')
|
|
|
|
| 793 |
|
| 794 |
Example:
|
| 795 |
>>> device = torch.device('cuda:0')
|
| 796 |
+
>>> smd = SemanticDrawPipeline(device)
|
| 797 |
>>> prompts = {... specify prompts}
|
| 798 |
>>> masks = {... specify mask tensors}
|
| 799 |
>>> height, width = masks.shape[-2:]
|
|
|
|
| 882 |
|
| 883 |
# prompts is None: return background.
|
| 884 |
# masks is None but prompts is not None: return prompts
|
| 885 |
+
# masks is not None and prompts is not None: Do SemanticDraw.
|
| 886 |
|
| 887 |
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
| 888 |
if background is None and background_prompt is not None:
|
|
|
|
| 1104 |
image = blend(image, background[0], fg_mask)
|
| 1105 |
else:
|
| 1106 |
image = T.ToPILImage()(image)
|
| 1107 |
+
return image
|