Update genstereo/GenStereo.py
Browse files- genstereo/GenStereo.py +17 -9
genstereo/GenStereo.py
CHANGED
|
@@ -24,7 +24,7 @@ from .models import (
|
|
| 24 |
UNet3DConditionModel,
|
| 25 |
ReferenceAttentionControl
|
| 26 |
)
|
| 27 |
-
from .ops import
|
| 28 |
|
| 29 |
class AdaptiveFusionLayer(nn.Module):
|
| 30 |
def __init__(self):
|
|
@@ -47,8 +47,8 @@ class GenStereo():
|
|
| 47 |
pretrained_model_path: str = ''
|
| 48 |
checkpoint_name: str = ''
|
| 49 |
half_precision_weights: bool = False
|
| 50 |
-
height: int =
|
| 51 |
-
width: int =
|
| 52 |
num_inference_steps: int = 50
|
| 53 |
guidance_scale: float = 1.5
|
| 54 |
cfg: Config
|
|
@@ -88,18 +88,28 @@ class GenStereo():
|
|
| 88 |
def __init__(
|
| 89 |
self,
|
| 90 |
cfg: Optional[Union[dict, DictConfig]] = None,
|
| 91 |
-
device: Optional[str] = 'cuda:0'
|
|
|
|
| 92 |
) -> None:
|
| 93 |
self.cfg = OmegaConf.structured(self.Config(**cfg))
|
| 94 |
self.model_path = join(
|
| 95 |
self.cfg.pretrained_model_path, self.cfg.checkpoint_name
|
| 96 |
)
|
| 97 |
self.device = device
|
|
|
|
| 98 |
self.configure()
|
| 99 |
self.transform_pixels = transforms.Compose([
|
| 100 |
transforms.ToTensor(), # Converts image to Tensor
|
| 101 |
transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
|
| 102 |
-
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
def configure(self) -> None:
|
| 105 |
print(f"Loading GenStereo...")
|
|
@@ -108,10 +118,6 @@ class GenStereo():
|
|
| 108 |
self.dtype = (
|
| 109 |
torch.float16 if self.cfg.half_precision_weights else torch.float32
|
| 110 |
)
|
| 111 |
-
self.viewport_mtx: Float[Tensor, 'B 4 4'] = get_viewport_matrix(
|
| 112 |
-
self.cfg.width, self.cfg.height,
|
| 113 |
-
batch_size=1, device=self.device
|
| 114 |
-
).to(self.dtype)
|
| 115 |
|
| 116 |
# Load models.
|
| 117 |
self.load_models()
|
|
@@ -276,6 +282,8 @@ class GenStereo():
|
|
| 276 |
).image_embeds
|
| 277 |
|
| 278 |
image_prompt_embeds = clip_image_embeds.unsqueeze(1)
|
|
|
|
|
|
|
| 279 |
uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
|
| 280 |
|
| 281 |
image_prompt_embeds = torch.cat(
|
|
|
|
| 24 |
UNet3DConditionModel,
|
| 25 |
ReferenceAttentionControl
|
| 26 |
)
|
| 27 |
+
from .ops import convert_left_to_right, convert_left_to_right_torch
|
| 28 |
|
| 29 |
class AdaptiveFusionLayer(nn.Module):
|
| 30 |
def __init__(self):
|
|
|
|
| 47 |
pretrained_model_path: str = ''
|
| 48 |
checkpoint_name: str = ''
|
| 49 |
half_precision_weights: bool = False
|
| 50 |
+
height: int = 768
|
| 51 |
+
width: int = 768
|
| 52 |
num_inference_steps: int = 50
|
| 53 |
guidance_scale: float = 1.5
|
| 54 |
cfg: Config
|
|
|
|
| 88 |
def __init__(
|
| 89 |
self,
|
| 90 |
cfg: Optional[Union[dict, DictConfig]] = None,
|
| 91 |
+
device: Optional[str] = 'cuda:0',
|
| 92 |
+
sd_version: Optional[str] = 'v2.1'
|
| 93 |
) -> None:
|
| 94 |
self.cfg = OmegaConf.structured(self.Config(**cfg))
|
| 95 |
self.model_path = join(
|
| 96 |
self.cfg.pretrained_model_path, self.cfg.checkpoint_name
|
| 97 |
)
|
| 98 |
self.device = device
|
| 99 |
+
self.sd_version = sd_version
|
| 100 |
self.configure()
|
| 101 |
self.transform_pixels = transforms.Compose([
|
| 102 |
transforms.ToTensor(), # Converts image to Tensor
|
| 103 |
transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
|
| 104 |
+
])
|
| 105 |
+
if self.sd_version == "v1.5":
|
| 106 |
+
self.cfg.height = 512
|
| 107 |
+
self.cfg.width = 512
|
| 108 |
+
elif self.sd_version == "v2.1":
|
| 109 |
+
self.cfg.height = 768
|
| 110 |
+
self.cfg.width = 768
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unknown SD version: {self.sd_version}")
|
| 113 |
|
| 114 |
def configure(self) -> None:
|
| 115 |
print(f"Loading GenStereo...")
|
|
|
|
| 118 |
self.dtype = (
|
| 119 |
torch.float16 if self.cfg.half_precision_weights else torch.float32
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
# Load models.
|
| 123 |
self.load_models()
|
|
|
|
| 282 |
).image_embeds
|
| 283 |
|
| 284 |
image_prompt_embeds = clip_image_embeds.unsqueeze(1)
|
| 285 |
+
if self.sd_version == "v2.1":
|
| 286 |
+
image_prompt_embeds = F.pad(image_prompt_embeds, (0, 256), "constant", 0) # Now shape is (bs, 1, 1024)
|
| 287 |
uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
|
| 288 |
|
| 289 |
image_prompt_embeds = torch.cat(
|