Ayesha-Majeed commited on
Commit
31a266c
·
verified ·
1 Parent(s): 88118b8

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +35 -35
binary_segmentation.py CHANGED
@@ -495,57 +495,57 @@ class BinarySegmenter:
495
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
496
  ])
497
 
498
- def _load_birefnet(self):
499
- """Load BiRefNet model (best accuracy, larger)"""
500
- try:
501
- from transformers import AutoModelForImageSegmentation
502
-
503
- self.model = AutoModelForImageSegmentation.from_pretrained(
504
- 'ZhengPeng7/BiRefNet',
505
- trust_remote_code=True,
506
- cache_dir=str(self.cache_dir),
507
- torch_dtype=torch.float32,
508
- low_cpu_mem_usage=False
509
- )
510
-
511
- self.transform = transforms.Compose([
512
- transforms.Resize((320, 320)),
513
- transforms.ToTensor(),
514
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
515
- ])
516
- except ImportError:
517
- raise ImportError("BiRefNet requires: pip install transformers")
518
-
519
  # def _load_birefnet(self):
520
  # """Load BiRefNet model (best accuracy, larger)"""
521
  # try:
522
  # from transformers import AutoModelForImageSegmentation
523
-
524
  # self.model = AutoModelForImageSegmentation.from_pretrained(
525
  # 'ZhengPeng7/BiRefNet',
526
  # trust_remote_code=True,
527
  # cache_dir=str(self.cache_dir),
528
- # torch_dtype=torch.float32, # ✅ Keep FP32 for CPU
529
  # low_cpu_mem_usage=False
530
  # )
531
-
532
- # # ✅ QUANTIZE to INT8 for CPU speedup
533
- # if DEVICE == "cpu":
534
-
535
- # self.model = torch.quantization.quantize_dynamic(
536
- # self.model,
537
- # {torch.nn.Linear, torch.nn.Conv2d},
538
- # dtype=torch.qint8
539
- # )
540
- # logger.info("✅ BiRefNet quantized to INT8")
541
-
542
  # self.transform = transforms.Compose([
543
- # transforms.Resize((1024, 1024)),
544
  # transforms.ToTensor(),
545
  # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
546
  # ])
547
  # except ImportError:
548
  # raise ImportError("BiRefNet requires: pip install transformers")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
  def _load_rmbg(self):
551
  """Load RMBG model (good balance)"""
 
495
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
496
  ])
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  # def _load_birefnet(self):
499
  # """Load BiRefNet model (best accuracy, larger)"""
500
  # try:
501
  # from transformers import AutoModelForImageSegmentation
502
+
503
  # self.model = AutoModelForImageSegmentation.from_pretrained(
504
  # 'ZhengPeng7/BiRefNet',
505
  # trust_remote_code=True,
506
  # cache_dir=str(self.cache_dir),
507
+ # torch_dtype=torch.float32,
508
  # low_cpu_mem_usage=False
509
  # )
510
+
 
 
 
 
 
 
 
 
 
 
511
  # self.transform = transforms.Compose([
512
+ # transforms.Resize((320, 320)),
513
  # transforms.ToTensor(),
514
  # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
515
  # ])
516
  # except ImportError:
517
  # raise ImportError("BiRefNet requires: pip install transformers")
518
+
519
+ def _load_birefnet(self):
520
+ """Load BiRefNet model (best accuracy, larger)"""
521
+ try:
522
+ from transformers import AutoModelForImageSegmentation
523
+
524
+ self.model = AutoModelForImageSegmentation.from_pretrained(
525
+ 'ZhengPeng7/BiRefNet',
526
+ trust_remote_code=True,
527
+ cache_dir=str(self.cache_dir),
528
+ torch_dtype=torch.float32, # ✅ Keep FP32 for CPU
529
+ low_cpu_mem_usage=False
530
+ )
531
+
532
+ # ✅ QUANTIZE to INT8 for CPU speedup
533
+ if DEVICE == "cpu":
534
+
535
+ self.model = torch.quantization.quantize_dynamic(
536
+ self.model,
537
+ {torch.nn.Linear, torch.nn.Conv2d},
538
+ dtype=torch.qint8
539
+ )
540
+ logger.info("✅ BiRefNet quantized to INT8")
541
+
542
+ self.transform = transforms.Compose([
543
+ transforms.Resize((1024, 1024)),
544
+ transforms.ToTensor(),
545
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
546
+ ])
547
+ except ImportError:
548
+ raise ImportError("BiRefNet requires: pip install transformers")
549
 
550
  def _load_rmbg(self):
551
  """Load RMBG model (good balance)"""