Ashrafb committed on
Commit
0fe2699
·
verified ·
1 Parent(s): 708cc78

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -112
  2. main.py +80 -0
app.py DELETED
@@ -1,112 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torchvision.transforms.functional import normalize
5
- from huggingface_hub import hf_hub_download
6
- import gradio as gr
7
- from gradio_imageslider import ImageSlider
8
- from briarmbg import BriaRMBG
9
- import PIL
10
- from PIL import Image
11
- from typing import Tuple
12
-
13
- net=BriaRMBG()
14
- # model_path = "./model1.pth"
15
- model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
- if torch.cuda.is_available():
17
- net.load_state_dict(torch.load(model_path))
18
- net=net.cuda()
19
- else:
20
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
- net.eval()
22
-
23
-
24
- def resize_image(image):
25
- image = image.convert('RGB')
26
- model_input_size = (1024, 1024)
27
- image = image.resize(model_input_size, Image.BILINEAR)
28
- return image
29
-
30
-
31
- def process(image):
32
-
33
- # prepare input
34
- orig_image = Image.fromarray(image)
35
- w,h = orig_im_size = orig_image.size
36
- image = resize_image(orig_image)
37
- im_np = np.array(image)
38
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
- im_tensor = torch.unsqueeze(im_tensor,0)
40
- im_tensor = torch.divide(im_tensor,255.0)
41
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
- if torch.cuda.is_available():
43
- im_tensor=im_tensor.cuda()
44
-
45
- #inference
46
- result=net(im_tensor)
47
- # post process
48
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
- ma = torch.max(result)
50
- mi = torch.min(result)
51
- result = (result-mi)/(ma-mi)
52
- # image to pil
53
- im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
- pil_im = Image.fromarray(np.squeeze(im_array))
55
- # paste the mask on the original image
56
- new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
- new_im.paste(orig_image, mask=pil_im)
58
- # new_orig_image = orig_image.convert('RGBA')
59
-
60
- return new_im
61
- # return [new_orig_image, new_im]
62
-
63
-
64
- # block = gr.Blocks().queue()
65
-
66
- # with block:
67
- # gr.Markdown("## BRIA RMBG 1.4")
68
- # gr.HTML('''
69
- # <p style="margin-bottom: 10px; font-size: 94%">
70
- # This is a demo for BRIA RMBG 1.4 that using
71
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
72
- # </p>
73
- # ''')
74
- # with gr.Row():
75
- # with gr.Column():
76
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
77
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
78
- # run_button = gr.Button(value="Run")
79
-
80
- # with gr.Column():
81
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
82
- # ips = [input_image]
83
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
84
-
85
- # block.launch(debug = True)
86
-
87
- # block = gr.Blocks().queue()
88
-
89
- gr.Markdown("## BRIA RMBG 1.4")
90
- gr.HTML('''
91
- <p style="margin-bottom: 10px; font-size: 94%">
92
- This is a demo for BRIA RMBG 1.4 that using
93
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
94
- </p>
95
- ''')
96
-
97
- title = ""
98
-
99
-
100
-
101
- description = r"""
102
- """
103
- examples = [['./input.jpeg'],]
104
-
105
-
106
-
107
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
108
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
109
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
110
-
111
- if __name__ == "__main__":
112
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import io

# Third-party
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    StreamingResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import normalize
import PIL
from PIL import Image

# Local
from briarmbg import BriaRMBG

19
# FastAPI application exposing the RMBG background-removal model.
app = FastAPI()

# Instantiate the RMBG-1.4 network and fetch its weights from the
# Hugging Face Hub (cached locally after the first download).
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')

# Load the weights onto the GPU when available, otherwise onto CPU
# (map_location avoids CUDA deserialization errors on CPU-only hosts).
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
# Inference-only mode: disables dropout / batch-norm training behavior.
net.eval()
31
def resize_image(image):
    """Return *image* converted to RGB and resampled to the model's
    fixed 1024x1024 input resolution (aspect ratio is not preserved).
    """
    rgb = image.convert('RGB')
    return rgb.resize((1024, 1024), Image.BILINEAR)
36
+
37
def process_image(image):
    """Remove the background from a PIL image with the RMBG model.

    Args:
        image: input PIL.Image (any mode; converted to RGB internally).

    Returns:
        PIL.Image in RGBA mode at the input's original size, with the
        background fully transparent.
    """
    orig_image = image
    w, h = orig_image.size

    # Preprocess: resize to the model's 1024x1024 input, scale to [0, 1],
    # then shift to [-0.5, 0.5] (mean 0.5, std 1.0) as RMBG expects.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: no_grad prevents autograd from recording the forward
    # pass, which otherwise wastes memory on every request.
    with torch.no_grad():
        result = net(im_tensor)

    # Upsample the predicted matte back to the original resolution, then
    # min-max normalize it to [0, 1].
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Guard against a constant matte (ma == mi), which would divide by zero.
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        result = torch.zeros_like(result)

    # Use the matte as an alpha mask to composite the original image onto
    # a fully transparent canvas.
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
61
+
62
+ @app.post("/process-image/")
63
+ async def process_image_endpoint(file: UploadFile = File(...)):
64
+ contents = await file.read()
65
+ pil_image = Image.open(io.BytesIO(contents))
66
+ processed_image = process_image(pil_image)
67
+
68
+ # Save the processed image temporarily
69
+ temp_file_path = "processed_image.png"
70
+ processed_image.save(temp_file_path)
71
+
72
+ # Return the processed image
73
+ return FileResponse(temp_file_path, media_type="image/png")
74
+
75
+
76
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
77
+
78
+ @app.get("/")
79
+ def index() -> FileResponse:
80
+ return FileResponse(path="/app/static/index.html", media_type="text/html")