Ayesha-Majeed commited on
Commit
e4ff367
·
verified ·
1 Parent(s): 1652e64

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +34 -3
binary_segmentation.py CHANGED
@@ -494,19 +494,50 @@ class BinarySegmenter:
494
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
495
  ])
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  def _load_birefnet(self):
498
  """Load BiRefNet model (best accuracy, larger)"""
499
  try:
500
  from transformers import AutoModelForImageSegmentation
501
-
502
  self.model = AutoModelForImageSegmentation.from_pretrained(
503
  'ZhengPeng7/BiRefNet',
504
  trust_remote_code=True,
505
  cache_dir=str(self.cache_dir),
506
- torch_dtype=torch.float32,
507
  low_cpu_mem_usage=False
508
  )
509
-
 
 
 
 
 
 
 
 
 
 
510
  self.transform = transforms.Compose([
511
  transforms.Resize((320, 320)),
512
  transforms.ToTensor(),
 
494
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
495
  ])
496
 
497
+ # def _load_birefnet(self):
498
+ # """Load BiRefNet model (best accuracy, larger)"""
499
+ # try:
500
+ # from transformers import AutoModelForImageSegmentation
501
+
502
+ # self.model = AutoModelForImageSegmentation.from_pretrained(
503
+ # 'ZhengPeng7/BiRefNet',
504
+ # trust_remote_code=True,
505
+ # cache_dir=str(self.cache_dir),
506
+ # torch_dtype=torch.float32,
507
+ # low_cpu_mem_usage=False
508
+ # )
509
+
510
+ # self.transform = transforms.Compose([
511
+ # transforms.Resize((320, 320)),
512
+ # transforms.ToTensor(),
513
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
514
+ # ])
515
+ # except ImportError:
516
+ # raise ImportError("BiRefNet requires: pip install transformers")
517
+
518
  def _load_birefnet(self):
519
  """Load BiRefNet model (best accuracy, larger)"""
520
  try:
521
  from transformers import AutoModelForImageSegmentation
522
+
523
  self.model = AutoModelForImageSegmentation.from_pretrained(
524
  'ZhengPeng7/BiRefNet',
525
  trust_remote_code=True,
526
  cache_dir=str(self.cache_dir),
527
+ torch_dtype=torch.float32, # ✅ Keep FP32 for CPU
528
  low_cpu_mem_usage=False
529
  )
530
+
531
+ # ✅ QUANTIZE to INT8 for CPU speedup — NOTE(review): torch.quantization.quantize_dynamic only converts nn.Linear-type modules; nn.Conv2d is not supported by dynamic quantization and is left unchanged, so only the Linear layers of BiRefNet are actually quantized here
532
+ if DEVICE == "cpu":
533
+ import torch.quantization
534
+ self.model = torch.quantization.quantize_dynamic(
535
+ self.model,
536
+ {torch.nn.Linear, torch.nn.Conv2d},
537
+ dtype=torch.qint8
538
+ )
539
+ logger.info("✅ BiRefNet quantized to INT8")
540
+
541
  self.transform = transforms.Compose([
542
  transforms.Resize((320, 320)),
543
  transforms.ToTensor(),