mastari commited on
Commit
2792e1a
·
1 Parent(s): edb5977
Files changed (3) hide show
  1. .gitignore +142 -0
  2. BiRefNet_config.py +11 -0
  3. handler.py +33 -3
.gitignore ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom
2
+ e_*
3
+ .vscode
4
+ ckpt
5
+ preds
6
+ evaluation/eval-*
7
+ nohup.out*
8
+ tmp*
9
+ *.pth
10
+ core-*-python-*
11
+ .DS_Store
12
+ __MACOSX/
13
+
14
+ # Byte-compiled / optimized / DLL files
15
+ __pycache__/
16
+ *.py[cod]
17
+ *$py.class
18
+
19
+ # C extensions
20
+ *.so
21
+
22
+ # Distribution / packaging
23
+ .Python
24
+ build/
25
+ develop-eggs/
26
+ dist/
27
+ downloads/
28
+ eggs/
29
+ .eggs/
30
+ lib/
31
+ lib64/
32
+ parts/
33
+ sdist/
34
+ var/
35
+ wheels/
36
+ pip-wheel-metadata/
37
+ share/python-wheels/
38
+ *.egg-info/
39
+ .installed.cfg
40
+ *.egg
41
+ MANIFEST
42
+
43
+ # PyInstaller
44
+ # Usually these files are written by a python script from a template
45
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
46
+ *.manifest
47
+ *.spec
48
+
49
+ # Installer logs
50
+ pip-log.txt
51
+ pip-delete-this-directory.txt
52
+
53
+ # Unit test / coverage reports
54
+ htmlcov/
55
+ .tox/
56
+ .nox/
57
+ .coverage
58
+ .coverage.*
59
+ .cache
60
+ nosetests.xml
61
+ coverage.xml
62
+ *.cover
63
+ *.py,cover
64
+ .hypothesis/
65
+ .pytest_cache/
66
+
67
+ # Translations
68
+ *.mo
69
+ *.pot
70
+
71
+ # Django stuff:
72
+ *.log
73
+ local_settings.py
74
+ db.sqlite3
75
+ db.sqlite3-journal
76
+
77
+ # Flask stuff:
78
+ instance/
79
+ .webassets-cache
80
+
81
+ # Scrapy stuff:
82
+ .scrapy
83
+
84
+ # Sphinx documentation
85
+ docs/_build/
86
+
87
+ # PyBuilder
88
+ target/
89
+
90
+ # Jupyter Notebook
91
+ .ipynb_checkpoints
92
+
93
+ # IPython
94
+ profile_default/
95
+ ipython_config.py
96
+
97
+ # pyenv
98
+ .python-version
99
+
100
+ # pipenv
101
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
+ # install all needed dependencies.
105
+ #Pipfile.lock
106
+
107
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108
+ __pypackages__/
109
+
110
+ # Celery stuff
111
+ celerybeat-schedule
112
+ celerybeat.pid
113
+
114
+ # SageMath parsed files
115
+ *.sage.py
116
+
117
+ # Environments
118
+ .env
119
+ .venv
120
+ env/
121
+ venv/
122
+ ENV/
123
+ env.bak/
124
+ venv.bak/
125
+
126
+ # Spyder project settings
127
+ .spyderproject
128
+ .spyproject
129
+
130
+ # Rope project settings
131
+ .ropeproject
132
+
133
+ # mkdocs documentation
134
+ /site
135
+
136
+ # mypy
137
+ .mypy_cache/
138
+ .dmypy.json
139
+ dmypy.json
140
+
141
+ # Pyre type checker
142
+ .pyre/
BiRefNet_config.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class BiRefNetConfig(PretrainedConfig):
4
+ model_type = "SegformerForSemanticSegmentation"
5
+ def __init__(
6
+ self,
7
+ bb_pretrained=False,
8
+ **kwargs
9
+ ):
10
+ self.bb_pretrained = bb_pretrained
11
+ super().__init__(**kwargs)
handler.py CHANGED
@@ -1,7 +1,7 @@
1
  # handler.py — BiRefNet endpoint handler
2
  # Fully instrumented for debugging input structure and format.
3
 
4
- from typing import Dict, Any, Tuple
5
  import os
6
  import io
7
  import base64
@@ -79,6 +79,25 @@ usage_to_weights_file = {
79
  usage = "General"
80
  resolution = (1024, 1024)
81
  half_precision = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # ======================================================
84
  # Endpoint Handler
@@ -156,12 +175,23 @@ class EndpointHandler:
156
 
157
  pred = preds[0].squeeze()
158
  pred_pil = transforms.ToPILImage()(pred)
 
 
159
 
160
  image_masked = refine_foreground(image, pred_pil)
161
- image_masked.putalpha(pred_pil.resize(image.size))
162
 
163
  buffer = io.BytesIO()
164
  image_masked.save(buffer, format="PNG")
165
  encoded_result = base64.b64encode(buffer.getvalue()).decode("utf-8")
166
- return {"image_base64": encoded_result}
167
 
 
 
 
 
 
 
 
 
 
 
 
1
  # handler.py — BiRefNet endpoint handler
2
  # Fully instrumented for debugging input structure and format.
3
 
4
+ from typing import Dict, Any, Tuple, Optional
5
  import os
6
  import io
7
  import base64
 
79
  usage = "General"
80
  resolution = (1024, 1024)
81
  half_precision = True
82
+ SEGMENTATION_THRESHOLD = 0.05
83
+
84
+
85
+ def extract_bbox_from_mask(mask: Image.Image, threshold: float = SEGMENTATION_THRESHOLD) -> Optional[Dict[str, int]]:
86
+ """Compute a bounding box for the non-zero region of the mask."""
87
+ mask_gray = mask.convert("L")
88
+ mask_array = np.array(mask_gray, dtype=np.float32) / 255.0
89
+ binary = mask_array > threshold
90
+ if not np.any(binary):
91
+ return None
92
+ ys, xs = np.where(binary)
93
+ x_min, x_max = xs.min(), xs.max()
94
+ y_min, y_max = ys.min(), ys.max()
95
+ return {
96
+ "x": int(x_min),
97
+ "y": int(y_min),
98
+ "width": int(x_max - x_min + 1),
99
+ "height": int(y_max - y_min + 1),
100
+ }
101
 
102
  # ======================================================
103
  # Endpoint Handler
 
175
 
176
  pred = preds[0].squeeze()
177
  pred_pil = transforms.ToPILImage()(pred)
178
+ mask_resized = pred_pil.resize(image.size)
179
+ mask_bbox = extract_bbox_from_mask(mask_resized)
180
 
181
  image_masked = refine_foreground(image, pred_pil)
182
+ image_masked.putalpha(mask_resized)
183
 
184
  buffer = io.BytesIO()
185
  image_masked.save(buffer, format="PNG")
186
  encoded_result = base64.b64encode(buffer.getvalue()).decode("utf-8")
 
187
 
188
+ mask_buffer = io.BytesIO()
189
+ mask_resized.save(mask_buffer, format="PNG")
190
+ encoded_mask = base64.b64encode(mask_buffer.getvalue()).decode("utf-8")
191
+
192
+ return {
193
+ "image_base64": encoded_result,
194
+ "mask_base64": encoded_mask,
195
+ "mask_bbox": mask_bbox,
196
+ "mask_size": {"width": mask_resized.width, "height": mask_resized.height},
197
+ }