append probablities to the output
Browse files- demo.py +7 -1
- opt.py +8 -5
- requirements.txt +1 -0
demo.py
CHANGED
|
@@ -31,6 +31,11 @@ def demo(folder_path, output_path=Path("tmp")):
|
|
| 31 |
image = image.to(opt.device).unsqueeze(0)
|
| 32 |
outputs = model(image, seg_size=image_size)
|
| 33 |
out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
overlay = draw_segmentation_masks(
|
| 36 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
|
@@ -43,7 +48,8 @@ def demo(folder_path, output_path=Path("tmp")):
|
|
| 43 |
],
|
| 44 |
padding=5,
|
| 45 |
)
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
if __name__ == "__main__":
|
|
|
|
| 31 |
image = image.to(opt.device).unsqueeze(0)
|
| 32 |
outputs = model(image, seg_size=image_size)
|
| 33 |
out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
|
| 34 |
+
pred = outputs["ensemble"]["out_map"].max().item()
|
| 35 |
+
if pred > opt.mask_threshold:
|
| 36 |
+
print(f"Found manipulation in {image_path.name}")
|
| 37 |
+
else:
|
| 38 |
+
print(f"No manipulation found in {image_path.name}")
|
| 39 |
|
| 40 |
overlay = draw_segmentation_masks(
|
| 41 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
|
|
|
| 48 |
],
|
| 49 |
padding=5,
|
| 50 |
)
|
| 51 |
+
image_name = image_path.stem + f"-{pred:.2f}" + image_path.suffix
|
| 52 |
+
save_image(grid_image, (output_path / image_name).as_posix())
|
| 53 |
|
| 54 |
|
| 55 |
if __name__ == "__main__":
|
opt.py
CHANGED
|
@@ -10,8 +10,8 @@ import yaml
|
|
| 10 |
from termcolor import cprint
|
| 11 |
|
| 12 |
|
| 13 |
-
def load_dataset_arguments(opt):
|
| 14 |
-
if opt.load is None:
|
| 15 |
return
|
| 16 |
|
| 17 |
# exclude parameters assigned in the command
|
|
@@ -24,7 +24,10 @@ def load_dataset_arguments(opt):
|
|
| 24 |
arguments = []
|
| 25 |
|
| 26 |
# load parameters in the yaml file
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
with open(opt.load, "r") as f:
|
| 29 |
yaml_arguments = yaml.safe_load(f)
|
| 30 |
# TODO this should be verified
|
|
@@ -33,7 +36,7 @@ def load_dataset_arguments(opt):
|
|
| 33 |
setattr(opt, k, v)
|
| 34 |
|
| 35 |
|
| 36 |
-
def get_opt(additional_parsers: Optional[List] = None):
|
| 37 |
parents = [get_arguments_parser()]
|
| 38 |
if additional_parsers:
|
| 39 |
parents.extend(additional_parsers)
|
|
@@ -43,7 +46,7 @@ def get_opt(additional_parsers: Optional[List] = None):
|
|
| 43 |
opt = parser.parse_known_args()[0]
|
| 44 |
|
| 45 |
# load dataset argument file
|
| 46 |
-
load_dataset_arguments(opt)
|
| 47 |
|
| 48 |
# user-defined warnings and assertions
|
| 49 |
if opt.decoder.lower() not in ["c1"]:
|
|
|
|
| 10 |
from termcolor import cprint
|
| 11 |
|
| 12 |
|
| 13 |
+
def load_dataset_arguments(cfg_path, opt):
|
| 14 |
+
if opt.load is None and cfg_path is None:
|
| 15 |
return
|
| 16 |
|
| 17 |
# exclude parameters assigned in the command
|
|
|
|
| 24 |
arguments = []
|
| 25 |
|
| 26 |
# load parameters in the yaml file
|
| 27 |
+
if cfg_path is not None:
|
| 28 |
+
opt.load = cfg_path
|
| 29 |
+
else:
|
| 30 |
+
assert os.path.exists(opt.load)
|
| 31 |
with open(opt.load, "r") as f:
|
| 32 |
yaml_arguments = yaml.safe_load(f)
|
| 33 |
# TODO this should be verified
|
|
|
|
| 36 |
setattr(opt, k, v)
|
| 37 |
|
| 38 |
|
| 39 |
+
def get_opt(cfg_path: Optional[str] = None, additional_parsers: Optional[List] = None):
|
| 40 |
parents = [get_arguments_parser()]
|
| 41 |
if additional_parsers:
|
| 42 |
parents.extend(additional_parsers)
|
|
|
|
| 46 |
opt = parser.parse_known_args()[0]
|
| 47 |
|
| 48 |
# load dataset argument file
|
| 49 |
+
load_dataset_arguments(cfg_path, opt)
|
| 50 |
|
| 51 |
# user-defined warnings and assertions
|
| 52 |
if opt.decoder.lower() not in ["c1"]:
|
requirements.txt
CHANGED
|
@@ -26,3 +26,4 @@ timm==0.9.12
|
|
| 26 |
torch==1.12.1+cu116
|
| 27 |
torchvision==0.13.1+cu116
|
| 28 |
tqdm==4.64.1
|
|
|
|
|
|
| 26 |
torch==1.12.1+cu116
|
| 27 |
torchvision==0.13.1+cu116
|
| 28 |
tqdm==4.64.1
|
| 29 |
+
markupsafe==2.0.1
|