gangweix commited on
Commit
8a61297
·
verified ·
1 Parent(s): 219e9b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -51,6 +51,8 @@ css = """
51
 
52
  set_seed(666)
53
 
 
 
54
  # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
55
  # DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
56
  default_steps = 20
@@ -63,8 +65,10 @@ ckpt_path = hf_hub_download(
63
  state_dict = torch.load(ckpt_path, map_location="cpu")
64
  model.load_state_dict(state_dict, strict=False)
65
  model = model.eval()
 
66
 
67
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").eval()
 
68
 
69
 
70
  def main(share=True):
 
51
 
52
  set_seed(666)
53
 
54
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
  # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
  # DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
58
  default_steps = 20
 
65
  state_dict = torch.load(ckpt_path, map_location="cpu")
66
  model.load_state_dict(state_dict, strict=False)
67
  model = model.eval()
68
+ model = model.to(DEVICE)
69
 
70
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").eval()
71
+ moge_model = moge_model.to(DEVICE)
72
 
73
 
74
  def main(share=True):