Commit
·
bca1df1
1
Parent(s):
ad04501
Update codes if handler.py.
Browse files- handler.py +7 -1
handler.py
CHANGED
|
@@ -62,6 +62,7 @@ class ImagePreprocessor():
|
|
| 62 |
|
| 63 |
usage_to_weights_file = {
|
| 64 |
'General': 'BiRefNet',
|
|
|
|
| 65 |
'General-Lite': 'BiRefNet_lite',
|
| 66 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
| 67 |
'General-reso_512': 'BiRefNet-reso_512',
|
|
@@ -82,9 +83,12 @@ if usage in ['General-Lite-2K']:
|
|
| 82 |
resolution = (2560, 1440)
|
| 83 |
elif usage in ['General-reso_512']:
|
| 84 |
resolution = (512, 512)
|
|
|
|
|
|
|
| 85 |
else:
|
| 86 |
resolution = (1024, 1024)
|
| 87 |
|
|
|
|
| 88 |
|
| 89 |
class EndpointHandler():
|
| 90 |
def __init__(self, path=''):
|
|
@@ -93,6 +97,8 @@ class EndpointHandler():
|
|
| 93 |
)
|
| 94 |
self.birefnet.to(device)
|
| 95 |
self.birefnet.eval()
|
|
|
|
|
|
|
| 96 |
|
| 97 |
def __call__(self, data: Dict[str, Any]):
|
| 98 |
"""
|
|
@@ -122,7 +128,7 @@ class EndpointHandler():
|
|
| 122 |
|
| 123 |
# Prediction
|
| 124 |
with torch.no_grad():
|
| 125 |
-
preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
| 126 |
pred = preds[0].squeeze()
|
| 127 |
|
| 128 |
# Show Results
|
|
|
|
| 62 |
|
| 63 |
usage_to_weights_file = {
|
| 64 |
'General': 'BiRefNet',
|
| 65 |
+
'General-HR': 'BiRefNet_HR',
|
| 66 |
'General-Lite': 'BiRefNet_lite',
|
| 67 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
| 68 |
'General-reso_512': 'BiRefNet-reso_512',
|
|
|
|
| 83 |
resolution = (2560, 1440)
|
| 84 |
elif usage in ['General-reso_512']:
|
| 85 |
resolution = (512, 512)
|
| 86 |
+
elif usage in ['General-HR']:
|
| 87 |
+
resolution = (2048, 2048)
|
| 88 |
else:
|
| 89 |
resolution = (1024, 1024)
|
| 90 |
|
| 91 |
+
half_precision = True
|
| 92 |
|
| 93 |
class EndpointHandler():
|
| 94 |
def __init__(self, path=''):
|
|
|
|
| 97 |
)
|
| 98 |
self.birefnet.to(device)
|
| 99 |
self.birefnet.eval()
|
| 100 |
+
if half_precision:
|
| 101 |
+
self.birefnet.half()
|
| 102 |
|
| 103 |
def __call__(self, data: Dict[str, Any]):
|
| 104 |
"""
|
|
|
|
| 128 |
|
| 129 |
# Prediction
|
| 130 |
with torch.no_grad():
|
| 131 |
+
preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
|
| 132 |
pred = preds[0].squeeze()
|
| 133 |
|
| 134 |
# Show Results
|