Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import cv2
|
| 2 |
import gradio as gr
|
| 3 |
import os
|
| 4 |
-
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
@@ -9,18 +10,24 @@ from torch.autograd import Variable
|
|
| 9 |
from torchvision import transforms
|
| 10 |
import torch.nn.functional as F
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
#
|
| 20 |
if not os.path.exists("saved_models"):
|
| 21 |
os.makedirs("saved_models")
|
| 22 |
-
|
| 23 |
-
# Automatically download isnet.pth from ECCV2022/dis-background-removal if not present
|
| 24 |
isnet_path = "saved_models/isnet.pth"
|
| 25 |
if not os.path.exists(isnet_path):
|
| 26 |
print("Downloading isnet.pth from ECCV2022/dis-background-removal...")
|
|
@@ -31,6 +38,13 @@ if not os.path.exists(isnet_path):
|
|
| 31 |
local_dir_use_symlinks=False
|
| 32 |
)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
class GOSNormalize(object):
|
| 35 |
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
| 36 |
self.mean = mean
|
|
@@ -60,21 +74,21 @@ def build_model(hypar, device):
|
|
| 60 |
def predict(net, inputs_val, shapes_val, hypar, device):
|
| 61 |
net.eval()
|
| 62 |
inputs_val = inputs_val.type(torch.FloatTensor).to(device)
|
| 63 |
-
with torch.no_grad():
|
| 64 |
inputs_val_v = Variable(inputs_val)
|
| 65 |
ds_val = net(inputs_val_v)[0]
|
| 66 |
-
pred_val = ds_val[0][0, :, :, :]
|
| 67 |
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
|
| 68 |
ma = torch.max(pred_val)
|
| 69 |
mi = torch.min(pred_val)
|
| 70 |
-
pred_val = (pred_val - mi) / (ma - mi)
|
| 71 |
return (pred_val.cpu().numpy() * 255).astype(np.uint8)
|
| 72 |
|
| 73 |
# Set Parameters
|
| 74 |
hypar = {
|
| 75 |
"model_path": "saved_models",
|
| 76 |
"restore_model": "isnet.pth",
|
| 77 |
-
"cache_size": [512, 512],
|
| 78 |
"input_size": [512, 512],
|
| 79 |
"crop_size": [512, 512],
|
| 80 |
"model": ISNetDIS()
|
|
|
|
| 1 |
import cv2
|
| 2 |
import gradio as gr
|
| 3 |
import os
|
| 4 |
+
import requests # Added for GitHub downloads
|
| 5 |
+
from huggingface_hub import hf_hub_download # For Hugging Face download
|
| 6 |
from PIL import Image
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
|
|
|
| 10 |
from torchvision import transforms
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
| 13 |
+
# Automatically download required files
|
| 14 |
+
# 1. data_loader_cache.py from GitHub
|
| 15 |
+
if not os.path.exists("data_loader_cache.py"):
|
| 16 |
+
print("Downloading data_loader_cache.py...")
|
| 17 |
+
response = requests.get("https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/data_loader_cache.py")
|
| 18 |
+
with open("data_loader_cache.py", "wb") as f:
|
| 19 |
+
f.write(response.content)
|
| 20 |
|
| 21 |
+
# 2. models.py from GitHub
|
| 22 |
+
if not os.path.exists("models.py"):
|
| 23 |
+
print("Downloading models.py...")
|
| 24 |
+
response = requests.get("https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/models.py")
|
| 25 |
+
with open("models.py", "wb") as f:
|
| 26 |
+
f.write(response.content)
|
| 27 |
|
| 28 |
+
# 3. isnet.pth from ECCV2022/dis-background-removal
|
| 29 |
if not os.path.exists("saved_models"):
|
| 30 |
os.makedirs("saved_models")
|
|
|
|
|
|
|
| 31 |
isnet_path = "saved_models/isnet.pth"
|
| 32 |
if not os.path.exists(isnet_path):
|
| 33 |
print("Downloading isnet.pth from ECCV2022/dis-background-removal...")
|
|
|
|
| 38 |
local_dir_use_symlinks=False
|
| 39 |
)
|
| 40 |
|
| 41 |
+
# Project imports
|
| 42 |
+
from data_loader_cache import normalize, im_reader, im_preprocess
|
| 43 |
+
from models import *
|
| 44 |
+
|
| 45 |
+
# Helpers
|
| 46 |
+
device = 'cpu' # Free Hugging Face Space uses CPU
|
| 47 |
+
|
| 48 |
class GOSNormalize(object):
|
| 49 |
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
| 50 |
self.mean = mean
|
|
|
|
| 74 |
def predict(net, inputs_val, shapes_val, hypar, device):
|
| 75 |
net.eval()
|
| 76 |
inputs_val = inputs_val.type(torch.FloatTensor).to(device)
|
| 77 |
+
with torch.no_grad():
|
| 78 |
inputs_val_v = Variable(inputs_val)
|
| 79 |
ds_val = net(inputs_val_v)[0]
|
| 80 |
+
pred_val = ds_val[0][0, :, :, :]
|
| 81 |
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
|
| 82 |
ma = torch.max(pred_val)
|
| 83 |
mi = torch.min(pred_val)
|
| 84 |
+
pred_val = (pred_val - mi) / (ma - mi)
|
| 85 |
return (pred_val.cpu().numpy() * 255).astype(np.uint8)
|
| 86 |
|
| 87 |
# Set Parameters
|
| 88 |
hypar = {
|
| 89 |
"model_path": "saved_models",
|
| 90 |
"restore_model": "isnet.pth",
|
| 91 |
+
"cache_size": [512, 512],
|
| 92 |
"input_size": [512, 512],
|
| 93 |
"crop_size": [512, 512],
|
| 94 |
"model": ISNetDIS()
|