AkashKumarave commited on
Commit
6ecfeba
·
verified ·
1 Parent(s): 0d3d566

Upload 6 files

Browse files
Files changed (6) hide show
  1. .gitattributes +33 -35
  2. README.md +16 -12
  3. app.py +153 -201
  4. requirements.txt +8 -11
  5. robot.png +3 -0
  6. ship.png +3 -0
.gitattributes CHANGED
@@ -1,35 +1,33 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ robot.png filter=lfs diff=lfs merge=lfs -text
33
+ ship.png filter=lfs diff=lfs merge=lfs -text
 
 
README.md CHANGED
@@ -1,12 +1,16 @@
1
- ---
2
- title: Testing
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
+ ---
2
+ title: DIS Background Removal
3
+ emoji: 🔥 🌠 🏰
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.23.1
8
+ python_version: 3.11.10
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ models:
13
+ - doevent/dis
14
+ ---
15
+
16
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,201 +1,153 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel
5
- import logging
6
- import torch
7
- from diffusers import StableDiffusionImg2ImgPipeline, EulerDiscreteScheduler
8
- from PIL import Image
9
- import numpy as np
10
- import cv2
11
- import mediapipe as mp
12
- import base64
13
- from io import BytesIO
14
- import uvicorn
15
-
16
- # Set up logging
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
- app = FastAPI()
21
-
22
- # Add CORS middleware to allow requests from Framer
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"], # In production, restrict to your Framer domain
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- pipe = None
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
-
34
- def initialize_pipeline():
35
- global pipe
36
- try:
37
- logger.info("Starting pipeline initialization...")
38
- model_id = "runwayml/stable-diffusion-v1-5"
39
- scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
40
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
41
- model_id,
42
- scheduler=scheduler,
43
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
44
- low_cpu_mem_usage=True
45
- )
46
- pipe = pipe.to(device)
47
- logger.info(f"Stable Diffusion pipeline initialized successfully on {device}.")
48
- except Exception as e:
49
- logger.error(f"Failed to initialize pipeline: {str(e)}", exc_info=True)
50
- raise
51
-
52
- try:
53
- initialize_pipeline()
54
- except Exception:
55
- logger.error("Application failed to start due to pipeline initialization error.")
56
- raise
57
-
58
- mp_face_detection = mp.solutions.face_detection
59
- face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)
60
-
61
- def extract_face(image: Image.Image):
62
- try:
63
- logger.info("Starting face extraction.")
64
- img_np = np.array(image)
65
- img_rgb = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
66
- results = face_detection.process(img_rgb)
67
- if not results.detections:
68
- logger.warning("No face detected in the reference image.")
69
- raise HTTPException(status_code=400, detail="No face detected in the reference image.")
70
-
71
- detection = results.detections[0]
72
- bbox = detection.location_data.relative_bounding_box
73
- ih, iw, _ = img_np.shape
74
- x, y, w, h = int(bbox.xmin * iw), int(bbox.ymin * ih), int(bbox.width * iw), int(bbox.height * ih)
75
- x, y = max(0, x), max(0, y)
76
- w, h = min(w, iw - x), min(h, ih - y)
77
- face = img_np[y:y+h, x:x+w]
78
- logger.info("Face extracted successfully.")
79
- return Image.fromarray(face), (x, y, w, h), image.size
80
- except Exception as e:
81
- logger.error(f"Error in extract_face: {str(e)}", exc_info=True)
82
- raise HTTPException(status_code=500, detail=f"Face extraction failed: {str(e)}")
83
-
84
- def scale_coords(coords, original_size, new_size=(512, 512)):
85
- orig_w, orig_h = original_size
86
- new_w, new_h = new_size
87
- x, y, w, h = coords
88
- scale_x, scale_y = new_w / orig_w, new_h / orig_h
89
- return (int(x * scale_x), int(y * scale_y), int(w * scale_x), int(h * scale_y))
90
-
91
- def overlay_face(generated_img, face_img, face_coords):
92
- try:
93
- logger.info("Starting face overlay.")
94
- gen_np = np.array(generated_img)
95
- x, y, w, h = face_coords
96
- face_np = np.array(face_img.resize((w, h)))
97
- gen_bgr = cv2.cvtColor(gen_np, cv2.COLOR_RGB2BGR)
98
- mask = np.zeros((h, w), dtype=np.uint8)
99
- cv2.rectangle(mask, (0, 0), (w, h), 255, -1)
100
- center = (x + w // 2, y + h // 2)
101
- face_bgr = cv2.cvtColor(face_np, cv2.COLOR_RGB2BGR)
102
- result_bgr = cv2.seamlessClone(face_bgr, gen_bgr, mask, center, cv2.NORMAL_CLONE)
103
- result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
104
- logger.info("Face overlay completed.")
105
- return Image.fromarray(result_rgb)
106
- except Exception as e:
107
- logger.error(f"Error in overlay_face: {str(e)}", exc_info=True)
108
- raise HTTPException(status_code=500, detail=f"Overlay failed: {str(e)}")
109
-
110
- def image_to_base64(image: Image.Image) -> str:
111
- try:
112
- buffered = BytesIO()
113
- image.save(buffered, format="JPEG") # Changed to JPEG to match Framer client expectation
114
- img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
115
- logger.info("Image converted to base64 successfully.")
116
- return img_base64
117
- except Exception as e:
118
- logger.error(f"Error converting image to base64: {str(e)}", exc_info=True)
119
- raise HTTPException(status_code=500, detail=f"Base64 conversion failed: {str(e)}")
120
-
121
- @app.post("/predict")
122
- async def predict(
123
- prompt: str = Form(...),
124
- image: UploadFile = File(...),
125
- negative_prompt: str = Form("low quality, blurry"),
126
- seed: str = Form("66"), # Changed to str to match Framer client
127
- guidance_scale: str = Form("7.5"), # Changed to str to match Framer client
128
- num_inference_steps: str = Form("10"), # Changed to str to match Framer client
129
- strength: str = Form("0.75") # Changed to str to match Framer client
130
- ):
131
- global pipe
132
- try:
133
- if pipe is None:
134
- logger.error("Pipeline not initialized.")
135
- raise HTTPException(status_code=500, detail="Pipeline not initialized.")
136
-
137
- logger.info(f"Received inference request with prompt: {prompt}")
138
-
139
- # Convert string parameters to appropriate types
140
- try:
141
- seed = int(seed)
142
- guidance_scale = float(guidance_scale)
143
- num_inference_steps = int(num_inference_steps)
144
- strength = float(strength)
145
- except ValueError as e:
146
- logger.error(f"Invalid parameter format: {str(e)}")
147
- raise HTTPException(status_code=400, detail=f"Invalid parameter format: {str(e)}")
148
-
149
- # Load and process uploaded image
150
- logger.info("Loading uploaded image...")
151
- image_data = await image.read()
152
- ref_image = Image.open(BytesIO(image_data)).convert("RGB")
153
-
154
- # Extract face
155
- logger.info("Extracting face...")
156
- face_img, face_coords, original_size = extract_face(ref_image)
157
-
158
- # Resize reference image to 512x512
159
- logger.info("Resizing reference image...")
160
- init_image = ref_image.resize((512, 512))
161
-
162
- # Generate image with Stable Diffusion
163
- logger.info("Starting image generation...")
164
- generator = torch.Generator(device=device).manual_seed(seed)
165
- generated_img = pipe(
166
- prompt=prompt,
167
- image=init_image,
168
- strength=strength,
169
- num_inference_steps=num_inference_steps,
170
- guidance_scale=guidance_scale,
171
- negative_prompt=negative_prompt,
172
- generator=generator
173
- ).images[0]
174
- logger.info("Image generation completed.")
175
-
176
- # Scale face coordinates and overlay
177
- logger.info("Scaling coordinates and overlaying face...")
178
- scaled_coords = scale_coords(face_coords, original_size)
179
- final_img = overlay_face(generated_img, face_img, scaled_coords)
180
-
181
- # Convert to base64
182
- logger.info("Converting final image to base64...")
183
- result_base64 = image_to_base64(final_img)
184
-
185
- logger.info("Inference completed successfully.")
186
- return JSONResponse({
187
- "result_image": f"data:image/jpeg;base64,{result_base64}" # Match Framer client expectation
188
- })
189
- except HTTPException as e:
190
- logger.error(f"HTTP Exception: {str(e)}")
191
- raise
192
- except Exception as e:
193
- logger.error(f"Unexpected error in predict: {str(e)}", exc_info=True)
194
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
195
-
196
- @app.get("/")
197
- async def root():
198
- return {"status": "API is running"}
199
-
200
- if __name__ == "__main__":
201
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
import cv2
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# Fetch the DIS (Dichotomous Image Segmentation) sources at startup and
# flatten the IS-Net code into the working directory so the project
# imports below resolve.
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

# project imports (provided by the repository cloned above)
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

# Helpers
# Prefer GPU when available; everything below runs on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download official weights
# NOTE(review): nothing here actually downloads isnet.pth — it is assumed
# to already exist in the working directory (e.g. shipped with the Space)
# and is only moved into saved_models/ on first run. Confirm.
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    os.system("mv isnet.pth saved_models/")
29
+
30
class GOSNormalize(object):
    '''
    Normalize the image using the project's `normalize` helper, so this
    object can be used inside torchvision's transforms.Compose.

    Args:
        mean: per-channel means (defaults to the ImageNet values).
        std: per-channel standard deviations (defaults to the ImageNet values).
    '''
    def __init__(self, mean=None, std=None):
        # Avoid mutable default arguments: build fresh lists per instance
        # so one instance's state can never leak into another.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

    def __call__(self, image):
        # Delegate the actual arithmetic to the project's normalize().
        image = normalize(image, self.mean, self.std)
        return image
41
+
42
+
43
# Inference-time preprocessing: a single GOSNormalize step with mean 0.5 and
# std 1.0 per channel (the values the official DIS demo uses).
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
44
+
45
def load_image(im_path, hypar):
    """Read an image from `im_path` and prepare it for the network.

    Returns a tuple of (normalized 1-image batch tensor, 1-row tensor with
    the image's original spatial shape).
    """
    raw = im_reader(im_path)
    prepped, original_shape = im_preprocess(raw, hypar["cache_size"])
    prepped = torch.divide(prepped, 255.0)  # scale pixel values to [0, 1]
    shape_tensor = torch.from_numpy(np.array(original_shape))
    # Add a leading batch dimension to both image and shape.
    return transform(prepped).unsqueeze(0), shape_tensor.unsqueeze(0)
51
+
52
+
53
def build_model(hypar, device):
    """Prepare the network described by `hypar`: optionally cast to half
    precision, restore a checkpoint if one is named, and return the model
    in eval mode on `device`.
    """
    model = hypar["model"]

    # Half precision requested: cast the whole net, but keep BatchNorm
    # layers in float32 (half-precision batch norm is numerically fragile).
    if hypar["model_digit"] == "half":
        model.half()
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.float()

    model.to(device)

    checkpoint_name = hypar["restore_model"]
    if checkpoint_name != "":
        checkpoint_path = hypar["model_path"] + "/" + checkpoint_name
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.to(device)
    model.eval()
    return model
70
+
71
+
72
def predict(net, inputs_val, shapes_val, hypar, device):
    '''
    Given a preprocessed image batch, predict the segmentation mask and
    return it as a uint8 array resized to the original image dimensions.

    Args:
        net: model whose forward pass returns a list of side outputs as its
            first element; the first side output is the best prediction.
        inputs_val: (1, C, H, W) input tensor.
        shapes_val: original (height, width) per image, indexable as
            shapes_val[0][0] / shapes_val[0][1].
        hypar: settings dict; "model_digit" selects "full"/"half" precision.
        device: 'cuda' or 'cpu'.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    # Inference only: no autograd graph needed (Variable is deprecated and
    # a no-op in modern PyTorch, so a plain tensor move suffices).
    inputs_val_v = inputs_val.to(device)
    with torch.no_grad():
        ds_val = net(inputs_val_v)[0]  # list of side-output results

    # First side output of the first (only) batch item: (1, H, W).
    pred_val = ds_val[0][0, :, :, :]

    # Recover the prediction's spatial size to the original image size.
    # F.interpolate replaces the deprecated F.upsample.
    pred_val = torch.squeeze(
        F.interpolate(
            torch.unsqueeze(pred_val, 0),
            (shapes_val[0][0], shapes_val[0][1]),
            mode='bilinear',
        )
    )

    # Min-max normalize to [0, 1]; guard the constant-map case, which would
    # otherwise divide by zero and fill the mask with NaNs.
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    denom = ma - mi
    if denom > 0:
        pred_val = (pred_val - mi) / denom  # max = 1
    else:
        pred_val = torch.zeros_like(pred_val)

    if device == 'cuda':
        torch.cuda.empty_cache()
    # Scale to [0, 255] and hand back a CPU uint8 mask.
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
99
+
100
# Set Parameters
hypar = {} # parameters for inferencing


hypar["model_path"] ="./saved_models" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
hypar["interm_sup"] = False ## whether to activate intermediate feature supervision

## choose floating point accuracy --
hypar["model_digit"] = "full" ## "half" or "full" float accuracy
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution; can be configured to a different size

## data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## model input spatial size; usually the same value as hypar["cache_size"], meaning images are not further resized
hypar["crop_size"] = [1024, 1024] ## random crop size from the input; usually set smaller than hypar["cache_size"], e.g. [920,920], for data augmentation

hypar["model"] = ISNetDIS()

# Build Model
net = build_model(hypar, device)
122
+
123
+
124
def inference(image):
    """Remove the background from the image at path `image`.

    Returns a two-element list: the RGBA composite (predicted mask as the
    alpha channel) and the grayscale mask itself.
    """
    tensor, original_shape = load_image(image, hypar)
    mask_array = predict(net, tensor, original_shape, hypar, device)

    alpha = Image.fromarray(mask_array).convert('L')
    composite = Image.open(image).convert("RGB").copy()
    composite.putalpha(alpha)

    return [composite, alpha]
137
+
138
+
139
# Gradio UI metadata shown on the Space page.
title = "Highly Accurate Dichotomous Image Segmentation"
description = "This is an unofficial demo for DIS, a model that can remove the background from a given image. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.<br>GitHub: https://github.com/xuebinqin/DIS<br>Telegram bot: https://t.me/restoration_photo_bot<br>[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
article = "<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=max_skobeev_dis_cmp_public' alt='visitor badge'></center></div>"

# Wire the inference function into a Gradio app: one image in, two images
# out (RGBA composite + mask), with queueing and the API enabled.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png")],
    examples=[['robot.png'], ['ship.png']],
    title=title,
    description=description,
    article=article,
    flagging_mode="never",
    cache_mode="lazy",
).queue(api_open=True).launch(show_error=True, show_api=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,11 +1,8 @@
1
- torch==2.0.1
2
- diffusers==0.29.2
3
- transformers==4.41.2
4
- numpy==1.26.4
5
- Pillow==10.3.0
6
- opencv-python-headless==4.9.0.80
7
- mediapipe==0.10.14
8
- fastapi==0.111.0
9
- uvicorn==0.30.1
10
- requests==2.32.3
11
- pydantic==2.7.4
 
1
+ torch
2
+ torchvision
3
+ requests
4
+ gdown
5
+ matplotlib
6
+ opencv-python
7
+ Pillow
8
+ scikit-image
 
 
 
robot.png ADDED

Git LFS Details

  • SHA256: f2c1f7f0da9ec158a9a417198944afc378eed47a4499fb739288e884484384ef
  • Pointer size: 131 Bytes
  • Size of remote file: 818 kB
ship.png ADDED

Git LFS Details

  • SHA256: fee77596bad08603c301088e62ff5ae763caa7bcdb53ce7fe3e23ff40dabcf16
  • Pointer size: 131 Bytes
  • Size of remote file: 834 kB