Update app.py
Browse files
app.py
CHANGED
|
@@ -60,7 +60,7 @@ elif ckpt.endswith(".safetensors"):
|
|
| 60 |
model.load_state_dict(model_ckpt)
|
| 61 |
else:
|
| 62 |
raise NotImplementedError
|
| 63 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 64 |
model = model.to(device)
|
| 65 |
sampler = DDIMSampler(model)
|
| 66 |
|
|
@@ -116,7 +116,7 @@ opt.prompt = ''
|
|
| 116 |
opt.text_dir = True
|
| 117 |
opt.front_dir = '+z'
|
| 118 |
opt.force_cuda_rast = True
|
| 119 |
-
|
| 120 |
gui = GUI(opt)
|
| 121 |
###################################### INIT STAGE 2 #########################################
|
| 122 |
|
|
@@ -348,7 +348,7 @@ def process_stage2(input_model, input_text, input_dir, iters, output_model, outp
|
|
| 348 |
opt.mesh = input_model
|
| 349 |
|
| 350 |
# load mesh!
|
| 351 |
-
gui.renderer = gui.renderer_class(opt,
|
| 352 |
|
| 353 |
# set prompt
|
| 354 |
gui.prompt = opt.positive_prompt + ', ' + input_text
|
|
|
|
| 60 |
model.load_state_dict(model_ckpt)
|
| 61 |
else:
|
| 62 |
raise NotImplementedError
|
| 63 |
+
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
|
| 64 |
model = model.to(device)
|
| 65 |
sampler = DDIMSampler(model)
|
| 66 |
|
|
|
|
| 116 |
opt.text_dir = True
|
| 117 |
opt.front_dir = '+z'
|
| 118 |
opt.force_cuda_rast = True
|
| 119 |
+
device0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 120 |
gui = GUI(opt)
|
| 121 |
###################################### INIT STAGE 2 #########################################
|
| 122 |
|
|
|
|
| 348 |
opt.mesh = input_model
|
| 349 |
|
| 350 |
# load mesh!
|
| 351 |
+
gui.renderer = gui.renderer_class(opt, device0).to(device0)
|
| 352 |
|
| 353 |
# set prompt
|
| 354 |
gui.prompt = opt.positive_prompt + ', ' + input_text
|