Spaces:
Build error
Build error
Matteo Sirri
commited on
Commit
·
d41f3b7
1
Parent(s):
50aa67b
feat: add cuda support
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ import torchvision.transforms as T
|
|
| 12 |
|
| 13 |
logging.getLogger('PIL').setLevel(logging.CRITICAL)
|
| 14 |
|
| 15 |
-
device = torch.device("cpu")
|
| 16 |
|
| 17 |
|
| 18 |
def load_model(baseline: bool = False):
|
|
@@ -35,7 +35,9 @@ def frcnn_motsynth(image):
|
|
| 35 |
model = load_model()
|
| 36 |
transformEval = presets.DetectionPresetEval()
|
| 37 |
image_tensor = transformEval(image, None)[0]
|
|
|
|
| 38 |
prediction = model([image_tensor])[0]
|
|
|
|
| 39 |
image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
|
| 40 |
torchvision.io.write_png(image_w_bbox, "custom_out.png")
|
| 41 |
return "custom_out.png"
|
|
@@ -45,7 +47,9 @@ def frcnn_coco(image):
|
|
| 45 |
model = load_model(baseline=True)
|
| 46 |
transformEval = presets.DetectionPresetEval()
|
| 47 |
image_tensor = transformEval(image, None)[0]
|
|
|
|
| 48 |
prediction = model([image_tensor])[0]
|
|
|
|
| 49 |
image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
|
| 50 |
torchvision.io.write_png(image_w_bbox, "baseline_out.png")
|
| 51 |
return "baseline_out.png"
|
|
|
|
| 12 |
|
| 13 |
logging.getLogger('PIL').setLevel(logging.CRITICAL)
|
| 14 |
|
| 15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
|
| 17 |
|
| 18 |
def load_model(baseline: bool = False):
|
|
|
|
| 35 |
model = load_model()
|
| 36 |
transformEval = presets.DetectionPresetEval()
|
| 37 |
image_tensor = transformEval(image, None)[0]
|
| 38 |
+
image_tensor = image_tensor.to(device)
|
| 39 |
prediction = model([image_tensor])[0]
|
| 40 |
+
prediction = prediction.to(device)
|
| 41 |
image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
|
| 42 |
torchvision.io.write_png(image_w_bbox, "custom_out.png")
|
| 43 |
return "custom_out.png"
|
|
|
|
| 47 |
model = load_model(baseline=True)
|
| 48 |
transformEval = presets.DetectionPresetEval()
|
| 49 |
image_tensor = transformEval(image, None)[0]
|
| 50 |
+
image_tensor = image_tensor.to(device)
|
| 51 |
prediction = model([image_tensor])[0]
|
| 52 |
+
prediction = prediction.to(device)
|
| 53 |
image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
|
| 54 |
torchvision.io.write_png(image_w_bbox, "baseline_out.png")
|
| 55 |
return "baseline_out.png"
|