Spaces:
Build error
Build error
bugfix: fix CUDA with the latest PyTorch version — input type and weight type should be the same
Browse files
app.py
CHANGED
|
@@ -56,21 +56,34 @@ miyazaki_model = Transformer()
|
|
| 56 |
kon_model = Transformer()
|
| 57 |
|
| 58 |
enable_gpu = torch.cuda.is_available()
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
shinkai_model.load_state_dict(
|
| 62 |
-
torch.load(shinkai_model_hfhub,
|
| 63 |
)
|
| 64 |
hosoda_model.load_state_dict(
|
| 65 |
-
torch.load(hosoda_model_hfhub,
|
| 66 |
)
|
| 67 |
miyazaki_model.load_state_dict(
|
| 68 |
-
torch.load(miyazaki_model_hfhub,
|
| 69 |
)
|
| 70 |
kon_model.load_state_dict(
|
| 71 |
-
torch.load(kon_model_hfhub,
|
| 72 |
)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
shinkai_model.eval()
|
| 75 |
hosoda_model.eval()
|
| 76 |
miyazaki_model.eval()
|
|
@@ -118,7 +131,8 @@ def inference(img, style):
|
|
| 118 |
|
| 119 |
if enable_gpu:
|
| 120 |
logger.info(f"CUDA found. Using GPU.")
|
| 121 |
-
|
|
|
|
| 122 |
else:
|
| 123 |
logger.info(f"CUDA not found. Using CPU.")
|
| 124 |
input_image = Variable(input_image).float()
|
|
|
|
| 56 |
kon_model = Transformer()
|
| 57 |
|
| 58 |
enable_gpu = torch.cuda.is_available()
|
| 59 |
+
|
| 60 |
+
if enable_gpu:
|
| 61 |
+
# If you have multiple cards,
|
| 62 |
+
# you can assign to a specific card, e.g. "cuda:0" (equivalent to "cuda") or "cuda:1"
|
| 63 |
+
# Use the first card by default: "cuda"
|
| 64 |
+
device = torch.device("cuda")
|
| 65 |
+
else:
|
| 66 |
+
device = "cpu"
|
| 67 |
|
| 68 |
shinkai_model.load_state_dict(
|
| 69 |
+
torch.load(shinkai_model_hfhub, device)
|
| 70 |
)
|
| 71 |
hosoda_model.load_state_dict(
|
| 72 |
+
torch.load(hosoda_model_hfhub, device)
|
| 73 |
)
|
| 74 |
miyazaki_model.load_state_dict(
|
| 75 |
+
torch.load(miyazaki_model_hfhub, device)
|
| 76 |
)
|
| 77 |
kon_model.load_state_dict(
|
| 78 |
+
torch.load(kon_model_hfhub, device)
|
| 79 |
)
|
| 80 |
|
| 81 |
+
if enable_gpu:
|
| 82 |
+
shinkai_model = shinkai_model.to(device)
|
| 83 |
+
hosoda_model = hosoda_model.to(device)
|
| 84 |
+
miyazaki_model = miyazaki_model.to(device)
|
| 85 |
+
kon_model = kon_model.to(device)
|
| 86 |
+
|
| 87 |
shinkai_model.eval()
|
| 88 |
hosoda_model.eval()
|
| 89 |
miyazaki_model.eval()
|
|
|
|
| 131 |
|
| 132 |
if enable_gpu:
|
| 133 |
logger.info(f"CUDA found. Using GPU.")
|
| 134 |
+
# Allows specifying a particular card for computation
|
| 135 |
+
input_image = Variable(input_image).to(device)
|
| 136 |
else:
|
| 137 |
logger.info(f"CUDA not found. Using CPU.")
|
| 138 |
input_image = Variable(input_image).float()
|