Spaces:
Build error
Build error
Commit ·
26d4625
1
Parent(s): e71c28e
Update detect_from_videos.py
Browse files- detect_from_videos.py +3 -3
detect_from_videos.py
CHANGED
|
@@ -12,7 +12,7 @@ from model_core import Two_Stream_Net
|
|
| 12 |
from torchvision import transforms
|
| 13 |
|
| 14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
-
|
| 16 |
|
| 17 |
xception_default_data_transforms_256 = {
|
| 18 |
'train': transforms.Compose([
|
|
@@ -115,7 +115,7 @@ def predict_with_model(image, model, post_function=nn.Softmax(dim=1),
|
|
| 115 |
|
| 116 |
|
| 117 |
def test_full_image_network(video_path, model_path, output_path,
|
| 118 |
-
start_frame=0, end_frame=None, cuda=
|
| 119 |
"""
|
| 120 |
Reads a video and evaluates a subset of frames with the a detection network
|
| 121 |
that takes in a full frame. Outputs are only given if a face is present
|
|
@@ -150,7 +150,7 @@ def test_full_image_network(video_path, model_path, output_path,
|
|
| 150 |
# Load model
|
| 151 |
# model, *_ = model_selection(modelname='xception', num_out_classes=2)
|
| 152 |
model = Two_Stream_Net()
|
| 153 |
-
model.load_state_dict(torch.load(model_path))
|
| 154 |
model = model.to(device)
|
| 155 |
model.eval()
|
| 156 |
|
|
|
|
| 12 |
from torchvision import transforms
|
| 13 |
|
| 14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
+
map_location=torch.device('cpu')
|
| 16 |
|
| 17 |
xception_default_data_transforms_256 = {
|
| 18 |
'train': transforms.Compose([
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
def test_full_image_network(video_path, model_path, output_path,
|
| 118 |
+
start_frame=0, end_frame=None, cuda=False):
|
| 119 |
"""
|
| 120 |
Reads a video and evaluates a subset of frames with the a detection network
|
| 121 |
that takes in a full frame. Outputs are only given if a face is present
|
|
|
|
| 150 |
# Load model
|
| 151 |
# model, *_ = model_selection(modelname='xception', num_out_classes=2)
|
| 152 |
model = Two_Stream_Net()
|
| 153 |
+
model.load_state_dict(torch.load(model_path,map_location))
|
| 154 |
model = model.to(device)
|
| 155 |
model.eval()
|
| 156 |
|