Spaces:
Runtime error
Runtime error
Commit ·
448a2e3
1
Parent(s): e2e2967
Update model.py
Browse files
model.py
CHANGED
|
@@ -211,8 +211,10 @@ class Model:
|
|
| 211 |
device = 'cuda'
|
| 212 |
if base_model == 'sd-v1-4.ckpt':
|
| 213 |
model = self.model
|
|
|
|
| 214 |
else:
|
| 215 |
model = self.model_anything
|
|
|
|
| 216 |
# if current_base != base_model:
|
| 217 |
# ckpt = os.path.join("models", base_model)
|
| 218 |
# pl_sd = torch.load(ckpt, map_location="cpu")
|
|
@@ -254,7 +256,7 @@ class Model:
|
|
| 254 |
shape = [4, 64, 64]
|
| 255 |
|
| 256 |
# sampling
|
| 257 |
-
samples_ddim, _ =
|
| 258 |
conditioning=c,
|
| 259 |
batch_size=1,
|
| 260 |
shape=shape,
|
|
@@ -283,8 +285,10 @@ class Model:
|
|
| 283 |
device = 'cuda'
|
| 284 |
if base_model == 'sd-v1-4.ckpt':
|
| 285 |
model = self.model
|
|
|
|
| 286 |
else:
|
| 287 |
model = self.model_anything
|
|
|
|
| 288 |
# if current_base != base_model:
|
| 289 |
# ckpt = os.path.join("models", base_model)
|
| 290 |
# pl_sd = torch.load(ckpt, map_location="cpu")
|
|
@@ -347,7 +351,7 @@ class Model:
|
|
| 347 |
shape = [4, 64, 64]
|
| 348 |
|
| 349 |
# sampling
|
| 350 |
-
samples_ddim, _ =
|
| 351 |
conditioning=c,
|
| 352 |
batch_size=1,
|
| 353 |
shape=shape,
|
|
|
|
| 211 |
device = 'cuda'
|
| 212 |
if base_model == 'sd-v1-4.ckpt':
|
| 213 |
model = self.model
|
| 214 |
+
sampler = self.sampler
|
| 215 |
else:
|
| 216 |
model = self.model_anything
|
| 217 |
+
sampler = self.sampler_anything
|
| 218 |
# if current_base != base_model:
|
| 219 |
# ckpt = os.path.join("models", base_model)
|
| 220 |
# pl_sd = torch.load(ckpt, map_location="cpu")
|
|
|
|
| 256 |
shape = [4, 64, 64]
|
| 257 |
|
| 258 |
# sampling
|
| 259 |
+
samples_ddim, _ = sampler.sample(S=50,
|
| 260 |
conditioning=c,
|
| 261 |
batch_size=1,
|
| 262 |
shape=shape,
|
|
|
|
| 285 |
device = 'cuda'
|
| 286 |
if base_model == 'sd-v1-4.ckpt':
|
| 287 |
model = self.model
|
| 288 |
+
sampler = self.sampler
|
| 289 |
else:
|
| 290 |
model = self.model_anything
|
| 291 |
+
sampler = self.sampler_anything
|
| 292 |
# if current_base != base_model:
|
| 293 |
# ckpt = os.path.join("models", base_model)
|
| 294 |
# pl_sd = torch.load(ckpt, map_location="cpu")
|
|
|
|
| 351 |
shape = [4, 64, 64]
|
| 352 |
|
| 353 |
# sampling
|
| 354 |
+
samples_ddim, _ = sampler.sample(S=50,
|
| 355 |
conditioning=c,
|
| 356 |
batch_size=1,
|
| 357 |
shape=shape,
|