confidence threshold
Browse files
app.py
CHANGED
|
@@ -2,21 +2,21 @@ import numpy as np
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from infer import detections
|
|
|
|
| 5 |
import os
|
| 6 |
os.system("mkdir data")
|
| 7 |
os.system("mkdir data/models")
|
| 8 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
|
| 9 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
|
| 10 |
'''
|
| 11 |
-
|
| 12 |
-
def walt_demo(input_img):
|
| 13 |
#detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
|
| 14 |
if torch.cuda.is_available() == False:
|
| 15 |
device='cpu'
|
| 16 |
else:
|
| 17 |
device='cuda:0'
|
| 18 |
#detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
|
| 19 |
-
detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=
|
| 20 |
|
| 21 |
count = 0
|
| 22 |
#img = detect_people.run_on_image(input_img)
|
|
@@ -45,9 +45,9 @@ article="""
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
examples = [
|
| 48 |
-
'demo/images/img_1.jpg',
|
| 49 |
-
'demo/images/img_2.jpg',
|
| 50 |
-
'demo/images/img_4.png',
|
| 51 |
]
|
| 52 |
|
| 53 |
'''
|
|
@@ -58,9 +58,15 @@ img=walt_demo(img)
|
|
| 58 |
cv2.imwrite(filename.replace('/images/','/results/'),img)
|
| 59 |
cv2.imwrite('check.png',img)
|
| 60 |
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
demo = gr.Interface(walt_demo,
|
| 62 |
-
|
| 63 |
-
|
| 64 |
article=article,
|
| 65 |
title=title,
|
| 66 |
enable_queue=True,
|
|
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from infer import detections
|
| 5 |
+
'''
|
| 6 |
import os
|
| 7 |
os.system("mkdir data")
|
| 8 |
os.system("mkdir data/models")
|
| 9 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
|
| 10 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
|
| 11 |
'''
|
| 12 |
+
def walt_demo(input_img, confidence_threshold):
|
|
|
|
| 13 |
#detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
|
| 14 |
if torch.cuda.is_available() == False:
|
| 15 |
device='cpu'
|
| 16 |
else:
|
| 17 |
device='cuda:0'
|
| 18 |
#detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
|
| 19 |
+
detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)
|
| 20 |
|
| 21 |
count = 0
|
| 22 |
#img = detect_people.run_on_image(input_img)
|
|
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
examples = [
|
| 48 |
+
['demo/images/img_1.jpg',0.8],
|
| 49 |
+
['demo/images/img_2.jpg',0.8],
|
| 50 |
+
['demo/images/img_4.png',0.85],
|
| 51 |
]
|
| 52 |
|
| 53 |
'''
|
|
|
|
| 58 |
cv2.imwrite(filename.replace('/images/','/results/'),img)
|
| 59 |
cv2.imwrite('check.png',img)
|
| 60 |
'''
|
| 61 |
+
confidence_threshold = gr.Slider(minimum=0.3,
|
| 62 |
+
maximum=1.0,
|
| 63 |
+
step=0.01,
|
| 64 |
+
value=1.0,
|
| 65 |
+
label="Amodal Detection Confidence Threshold")
|
| 66 |
+
inputs = [gr.Image(), confidence_threshold]
|
| 67 |
demo = gr.Interface(walt_demo,
|
| 68 |
+
outputs="image",
|
| 69 |
+
inputs=inputs,
|
| 70 |
article=article,
|
| 71 |
title=title,
|
| 72 |
enable_queue=True,
|