minmax performance on those tensor cores lol
Browse files- inference_gradio.py +2 -2
inference_gradio.py
CHANGED
```diff
@@ -123,7 +123,7 @@ model.eval()
 if torch.cuda.is_available():
     model.cuda()
     if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
-        model.
+        model.to(dtype=torch.float16, memory_format=torch.channels_last)

 with open("JTP_PILOT/tags.json", "r") as file:
     tags = json.load(file)  # type: dict

@@ -139,7 +139,7 @@ def create_tags(image, threshold):
 if torch.cuda.is_available():
     tensor.cuda()
     if torch.cuda.get_device_capability()[0] >= 7:
-        tensor.
+        tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

 with torch.no_grad():
     logits = model(tensor)
```
(Note: the removed lines appear truncated in this capture — `model.` / `tensor.` — and are reproduced as shown; their full original text is not recoverable from this page.)