Spaces:
Runtime error
Runtime error
baixintech_zhangyiming_prod
commited on
Commit
·
53a3db7
1
Parent(s):
ad6f6d7
output with softmax
Browse files- app.py +8 -5
- wmdetection/pipelines/predictor.py +6 -0
app.py
CHANGED
|
@@ -12,13 +12,16 @@ model, transforms = get_watermarks_detection_model(
|
|
| 12 |
predictor = WatermarksPredictor(model, transforms, 'cpu')
|
| 13 |
|
| 14 |
|
| 15 |
-
def predict(image):
|
| 16 |
-
result = predictor.
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
examples = glob.glob(os.path.join('images', 'clean', '*'))
|
| 21 |
examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
iface.launch()
|
|
|
|
| 12 |
predictor = WatermarksPredictor(model, transforms, 'cpu')
|
| 13 |
|
| 14 |
|
| 15 |
+
def predict(image, threshold=0.5):
|
| 16 |
+
result = predictor.predict_image_confidence(image)
|
| 17 |
+
values = result.tolist()
|
| 18 |
+
wm_flag = 1 if values[1] >= threshold else 0
|
| 19 |
+
return 'watermarked' if wm_flag else 'clean', "%.4f" % values[1] # prints "watermarked"
|
| 20 |
|
| 21 |
|
| 22 |
examples = glob.glob(os.path.join('images', 'clean', '*'))
|
| 23 |
examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
|
| 24 |
+
examples = [[e, 0.5] for e in examples]
|
| 25 |
+
iface = gr.Interface(fn=predict, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(label="threshold", default=0.5), ],
|
| 26 |
+
examples=examples, outputs=[gr.outputs.Textbox(label="class"), gr.outputs.Textbox(label="wm_confidence")])
|
| 27 |
iface.launch()
|
wmdetection/pipelines/predictor.py
CHANGED
|
@@ -51,6 +51,12 @@ class WatermarksPredictor:
|
|
| 51 |
outputs = self.wm_model(input_img.to(self.device))
|
| 52 |
result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
|
| 53 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def run(self, files, num_workers=8, bs=8, pbar=True):
|
| 56 |
eval_dataset = ImageDataset(files, self.classifier_transforms)
|
|
|
|
| 51 |
outputs = self.wm_model(input_img.to(self.device))
|
| 52 |
result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
|
| 53 |
return result
|
| 54 |
+
|
| 55 |
+
def predict_image_confidence(self, pil_image):
|
| 56 |
+
pil_image = pil_image.convert("RGB")
|
| 57 |
+
input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
|
| 58 |
+
outputs = self.wm_model(input_img.to(self.device))
|
| 59 |
+
return torch.nn.functional.softmax(outputs, dim=1).cpu().reshape(-1)
|
| 60 |
|
| 61 |
def run(self, files, num_workers=8, bs=8, pbar=True):
|
| 62 |
eval_dataset = ImageDataset(files, self.classifier_transforms)
|