Commit
·
1e0fb18
1
Parent(s):
63c51ee
Add inference endpoint feature.
Browse files- handler.py +4 -5
handler.py
CHANGED
|
@@ -75,9 +75,7 @@ usage_to_weights_file = {
|
|
| 75 |
}
|
| 76 |
|
| 77 |
usage = 'General'
|
| 78 |
-
|
| 79 |
-
birefnet.to(device)
|
| 80 |
-
birefnet.eval()
|
| 81 |
|
| 82 |
# Set resolution
|
| 83 |
if usage in ['General-Lite-2K']:
|
|
@@ -91,9 +89,10 @@ else:
|
|
| 91 |
class EndpointHandler():
|
| 92 |
def __init__(self, path=""):
|
| 93 |
self.birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 94 |
-
|
| 95 |
)
|
| 96 |
self.birefnet.to(device)
|
|
|
|
| 97 |
|
| 98 |
def __call__(self, data: Dict[str, Any]):
|
| 99 |
"""
|
|
@@ -123,7 +122,7 @@ class EndpointHandler():
|
|
| 123 |
|
| 124 |
# Prediction
|
| 125 |
with torch.no_grad():
|
| 126 |
-
preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
| 127 |
pred = preds[0].squeeze()
|
| 128 |
|
| 129 |
# Show Results
|
|
|
|
| 75 |
}
|
| 76 |
|
| 77 |
usage = 'General'
|
| 78 |
+
model_repo = '/'.join(('zhengpeng7', usage_to_weights_file[usage]))
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# Set resolution
|
| 81 |
if usage in ['General-Lite-2K']:
|
|
|
|
| 89 |
class EndpointHandler():
|
| 90 |
def __init__(self, path=""):
|
| 91 |
self.birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 92 |
+
model_repo, trust_remote_code=True
|
| 93 |
)
|
| 94 |
self.birefnet.to(device)
|
| 95 |
+
self.birefnet.eval()
|
| 96 |
|
| 97 |
def __call__(self, data: Dict[str, Any]):
|
| 98 |
"""
|
|
|
|
| 122 |
|
| 123 |
# Prediction
|
| 124 |
with torch.no_grad():
|
| 125 |
+
preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
| 126 |
pred = preds[0].squeeze()
|
| 127 |
|
| 128 |
# Show Results
|