AkashKumarave commited on
Commit
9241860
·
verified ·
1 Parent(s): b145c13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import cv2
2
  import gradio as gr
3
  import os
4
- from huggingface_hub import hf_hub_download # Added for Hugging Face download
 
5
  from PIL import Image
6
  import numpy as np
7
  import torch
@@ -9,18 +10,24 @@ from torch.autograd import Variable
9
  from torchvision import transforms
10
  import torch.nn.functional as F
11
 
12
- # Project imports (assumes data_loader_cache.py and models.py are uploaded)
13
- from data_loader_cache import normalize, im_reader, im_preprocess
14
- from models import *
 
 
 
 
15
 
16
- # Helpers
17
- device = 'cpu' # Free Hugging Face Space uses CPU
 
 
 
 
18
 
19
- # Create directory for model weights
20
  if not os.path.exists("saved_models"):
21
  os.makedirs("saved_models")
22
-
23
- # Automatically download isnet.pth from ECCV2022/dis-background-removal if not present
24
  isnet_path = "saved_models/isnet.pth"
25
  if not os.path.exists(isnet_path):
26
  print("Downloading isnet.pth from ECCV2022/dis-background-removal...")
@@ -31,6 +38,13 @@ if not os.path.exists(isnet_path):
31
  local_dir_use_symlinks=False
32
  )
33
 
 
 
 
 
 
 
 
34
  class GOSNormalize(object):
35
  def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
36
  self.mean = mean
@@ -60,21 +74,21 @@ def build_model(hypar, device):
60
  def predict(net, inputs_val, shapes_val, hypar, device):
61
  net.eval()
62
  inputs_val = inputs_val.type(torch.FloatTensor).to(device)
63
- with torch.no_grad(): # Reduce memory usage
64
  inputs_val_v = Variable(inputs_val)
65
  ds_val = net(inputs_val_v)[0]
66
- pred_val = ds_val[0][0, :, :, :] # B x 1 x H x W
67
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
68
  ma = torch.max(pred_val)
69
  mi = torch.min(pred_val)
70
- pred_val = (pred_val - mi) / (ma - mi) # Normalize to [0, 1]
71
  return (pred_val.cpu().numpy() * 255).astype(np.uint8)
72
 
73
  # Set Parameters
74
  hypar = {
75
  "model_path": "saved_models",
76
  "restore_model": "isnet.pth",
77
- "cache_size": [512, 512], # Optimized for CPU
78
  "input_size": [512, 512],
79
  "crop_size": [512, 512],
80
  "model": ISNetDIS()
 
1
  import cv2
2
  import gradio as gr
3
  import os
4
+ import requests # Added for GitHub downloads
5
+ from huggingface_hub import hf_hub_download # For Hugging Face download
6
  from PIL import Image
7
  import numpy as np
8
  import torch
 
10
  from torchvision import transforms
11
  import torch.nn.functional as F
12
 
13
+ # Automatically download required files
14
+ # 1. data_loader_cache.py from GitHub
15
+ if not os.path.exists("data_loader_cache.py"):
16
+ print("Downloading data_loader_cache.py...")
17
+ response = requests.get("https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/data_loader_cache.py")
18
+ with open("data_loader_cache.py", "wb") as f:
19
+ f.write(response.content)
20
 
21
+ # 2. models.py from GitHub
22
+ if not os.path.exists("models.py"):
23
+ print("Downloading models.py...")
24
+ response = requests.get("https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/models.py")
25
+ with open("models.py", "wb") as f:
26
+ f.write(response.content)
27
 
28
+ # 3. isnet.pth from ECCV2022/dis-background-removal
29
  if not os.path.exists("saved_models"):
30
  os.makedirs("saved_models")
 
 
31
  isnet_path = "saved_models/isnet.pth"
32
  if not os.path.exists(isnet_path):
33
  print("Downloading isnet.pth from ECCV2022/dis-background-removal...")
 
38
  local_dir_use_symlinks=False
39
  )
40
 
41
+ # Project imports
42
+ from data_loader_cache import normalize, im_reader, im_preprocess
43
+ from models import *
44
+
45
+ # Helpers
46
+ device = 'cpu' # Free Hugging Face Space uses CPU
47
+
48
  class GOSNormalize(object):
49
  def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
50
  self.mean = mean
 
74
  def predict(net, inputs_val, shapes_val, hypar, device):
75
  net.eval()
76
  inputs_val = inputs_val.type(torch.FloatTensor).to(device)
77
+ with torch.no_grad():
78
  inputs_val_v = Variable(inputs_val)
79
  ds_val = net(inputs_val_v)[0]
80
+ pred_val = ds_val[0][0, :, :, :]
81
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
82
  ma = torch.max(pred_val)
83
  mi = torch.min(pred_val)
84
+ pred_val = (pred_val - mi) / (ma - mi)
85
  return (pred_val.cpu().numpy() * 255).astype(np.uint8)
86
 
87
  # Set Parameters
88
  hypar = {
89
  "model_path": "saved_models",
90
  "restore_model": "isnet.pth",
91
+ "cache_size": [512, 512],
92
  "input_size": [512, 512],
93
  "crop_size": [512, 512],
94
  "model": ISNetDIS()