Update birefnet.py
Browse files- birefnet.py +15 -2
birefnet.py
CHANGED
|
@@ -1993,9 +1993,22 @@ def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(
|
|
| 1993 |
return image
|
| 1994 |
|
| 1995 |
class BiRefNet(
|
| 1996 |
-
|
| 1997 |
-
):
|
| 1998 |
config_class = BiRefNetConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1999 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
| 2000 |
super(BiRefNet, self).__init__(config)
|
| 2001 |
bb_pretrained = config.bb_pretrained
|
|
|
|
| 1993 |
return image
|
| 1994 |
|
| 1995 |
class BiRefNet(
|
| 1996 |
+
PreTrainedModel
|
| 1997 |
+
):
|
| 1998 |
config_class = BiRefNetConfig
|
| 1999 |
+
|
| 2000 |
+
|
| 2001 |
+
@property
|
| 2002 |
+
def all_tied_weights_keys(self):
|
| 2003 |
+
keys = getattr(self, "_tied_weights_keys", None)
|
| 2004 |
+
if keys is None:
|
| 2005 |
+
return {}
|
| 2006 |
+
if isinstance(keys, dict):
|
| 2007 |
+
return keys
|
| 2008 |
+
try:
|
| 2009 |
+
return {k: None for k in keys}
|
| 2010 |
+
except TypeError:
|
| 2011 |
+
return {}
|
| 2012 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
| 2013 |
super(BiRefNet, self).__init__(config)
|
| 2014 |
bb_pretrained = config.bb_pretrained
|