Stable-X commited on
Commit
c6ebb49
·
verified ·
1 Parent(s): 65a7eee

Update trellis/pipelines/trellis_image_to_3d.py

Browse files
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -231,14 +231,17 @@ class TrellisImageTo3DPipeline(Pipeline):
231
  if scale < 1:
232
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
233
 
234
- # Get mask using BiRefNet
235
- mask = self._get_birefnet_mask(input)
236
 
237
- # Convert input to RGBA and apply mask
238
- input_rgba = input.convert('RGBA')
239
- input_array = np.array(input_rgba)
240
- input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
241
- output = Image.fromarray(input_array)
 
 
 
242
 
243
  # Process the output image
244
  output_np = np.array(output)
@@ -789,11 +792,11 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
789
  del new_pipeline.VGGT_model.point_head
790
  new_pipeline.VGGT_model.eval()
791
 
792
- new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
793
- 'ZhengPeng7/BiRefNet',
794
- trust_remote_code=True
795
- ).cpu()
796
- new_pipeline.birefnet_model.eval()
797
 
798
  new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
799
  new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
 
231
  if scale < 1:
232
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
233
 
234
+ # # Get mask using BiRefNet
235
+ # mask = self._get_birefnet_mask(input)
236
 
237
+ # # Convert input to RGBA and apply mask
238
+ # input_rgba = input.convert('RGBA')
239
+ # input_array = np.array(input_rgba)
240
+ # input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
241
+ # output = Image.fromarray(input_array)
242
+ if getattr(self, 'rembg_session', None) is None:
243
+ self.rembg_session = rembg.new_session('u2net')
244
+ output = rembg.remove(input, session=self.rembg_session)
245
 
246
  # Process the output image
247
  output_np = np.array(output)
 
792
  del new_pipeline.VGGT_model.point_head
793
  new_pipeline.VGGT_model.eval()
794
 
795
+ # new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
796
+ # 'ZhengPeng7/BiRefNet',
797
+ # trust_remote_code=True
798
+ # ).cpu()
799
+ # new_pipeline.birefnet_model.eval()
800
 
801
  new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
802
  new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']