Janeka commited on
Commit
07d78f3
·
verified ·
1 Parent(s): 21c339e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -20
app.py CHANGED
@@ -1,35 +1,99 @@
1
  import os
2
- import subprocess
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  import cv2
7
  from PIL import Image
8
- from huggingface_hub import hf_hub_download
9
 
10
- # Clone and install BiRefNet if not present
11
  if not os.path.exists('BiRefNet'):
12
- subprocess.run(['git', 'clone', 'https://github.com/ZhengPeng7/BiRefNet.git'])
13
- os.chdir('BiRefNet')
14
- subprocess.run(['pip', 'install', '-e', '.'])
15
- os.chdir('..')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Download model weights
18
  if not os.path.exists('BiRefNet.pth'):
19
- hf_hub_download(repo_id="ZhengPeng7/BiRefNet", filename="BiRefNet.pth", local_dir=".")
 
 
 
 
20
 
21
- # Import after installation
22
- from BiRefNet.models.BiRefNet import BiRefNet
23
- from BiRefNet.utils.dataloader import test_dataset
24
 
25
- # Rest of your code remains the same...
26
- model = BiRefNet()
27
- model.load_state_dict(torch.load('BiRefNet.pth', map_location=torch.device('cpu')))
 
28
  model.eval()
29
 
30
- def remove_background(input_image):
31
- # Your existing implementation
32
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- interface = gr.Interface(...)
35
- interface.launch()
 
1
  import os
2
+ import sys
3
+ import requests
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
  import cv2
8
  from PIL import Image
9
+ from tqdm import tqdm
10
 
11
+ # Download BiRefNet repository
12
  if not os.path.exists('BiRefNet'):
13
+ os.system('git clone https://github.com/ZhengPeng7/BiRefNet.git')
14
+ sys.path.insert(0, 'BiRefNet')
15
+
16
+ # Download model weights with progress bar
17
+ def download_file(url, filename):
18
+ response = requests.get(url, stream=True)
19
+ total_size = int(response.headers.get('content-length', 0))
20
+
21
+ with open(filename, 'wb') as f, tqdm(
22
+ desc=filename,
23
+ total=total_size,
24
+ unit='iB',
25
+ unit_scale=True,
26
+ unit_divisor=1024,
27
+ ) as bar:
28
+ for data in response.iter_content(chunk_size=1024):
29
+ size = f.write(data)
30
+ bar.update(size)
31
 
 
32
  if not os.path.exists('BiRefNet.pth'):
33
+ print("Downloading model weights...")
34
+ download_file(
35
+ "https://github.com/ZhengPeng7/BiRefNet/releases/download/v1.0/BiRefNet.pth",
36
+ "BiRefNet.pth"
37
+ )
38
 
39
+ # Import model after setting up
40
+ from models.BiRefNet import BiRefNet
41
+ from utils.dataloader import test_dataset
42
 
43
+ # Initialize model
44
+ device = torch.device('cpu')
45
+ model = BiRefNet().to(device)
46
+ model.load_state_dict(torch.load('BiRefNet.pth', map_location=device))
47
  model.eval()
48
 
49
+ def process_image(input_image):
50
+ # Convert to numpy array
51
+ image = np.array(input_image)
52
+ original_size = image.shape[:2]
53
+
54
+ # Resize for CPU processing
55
+ processed_size = (320, 320) # Reduced size for CPU
56
+ image = cv2.resize(image, processed_size)
57
+
58
+ # Preprocess
59
+ image = test_dataset.preprocess(image)
60
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
61
+
62
+ # Predict
63
+ with torch.no_grad():
64
+ pred = model(image)
65
+
66
+ # Post-process
67
+ mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
68
+ mask = cv2.resize(mask, original_size[::-1]) # Resize back to original
69
+
70
+ # Apply mask
71
+ result = cv2.bitwise_and(
72
+ np.array(input_image),
73
+ np.array(input_image),
74
+ mask=mask
75
+ )
76
+
77
+ return Image.fromarray(result), Image.fromarray(mask)
78
+
79
+ # Gradio interface
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# BiRefNet Background Remover (CPU)")
82
+ gr.Markdown("Works on CPU but may be slow (10-30 seconds per image)")
83
+
84
+ with gr.Row():
85
+ with gr.Column():
86
+ input_image = gr.Image(type="pil", label="Input Image")
87
+ submit = gr.Button("Remove Background")
88
+
89
+ with gr.Column():
90
+ output_image = gr.Image(type="pil", label="Result")
91
+ output_mask = gr.Image(type="pil", label="Mask")
92
+
93
+ submit.click(
94
+ fn=process_image,
95
+ inputs=input_image,
96
+ outputs=[output_image, output_mask]
97
+ )
98
 
99
+ demo.launch()