Slicelayers commited on
Commit
4f71413
·
verified ·
1 Parent(s): 6024e0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -16
app.py CHANGED
@@ -1,21 +1,43 @@
1
  import gradio as gr
2
- from rembg import remove
3
  from PIL import Image
4
- import io
5
-
6
- # 背景透過処理
7
- def remove_bg(img):
8
- # rembgで背景を削除
9
- result = remove(img)
10
- return result
11
-
12
- demo = gr.Interface(
13
- fn=remove_bg,
14
- inputs=gr.Image(type="pil"),
15
- outputs=gr.Image(type="pil"),
16
- title="背景透過テスト",
17
- description="アップロードした画像から背景を消して透過PNGを返します。"
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  if __name__ == "__main__":
21
  demo.launch()
 
1
  import gradio as gr
 
2
  from PIL import Image
3
+ import torch
4
+ from segment_anything import sam_model_registry, SamPredictor
5
+
6
+ # SAMの準備(モデルは小さい vit_b を利用)
7
+ sam_checkpoint = "sam_vit_b_01ec64.pth" # モデルファイル
8
+ model_type = "vit_b"
9
+
10
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
11
+ predictor = SamPredictor(sam)
12
+
13
+ def segment(img, points):
14
+ predictor.set_image(img)
15
+
16
+ input_points = []
17
+ for p in points:
18
+ input_points.append([p["x"], p["y"]])
19
+
20
+ input_labels = [1] * len(input_points) # クリック位置は前景扱い
21
+ masks, _, _ = predictor.predict(
22
+ point_coords=torch.tensor(input_points),
23
+ point_labels=torch.tensor(input_labels),
24
+ multimask_output=False,
25
+ )
26
+
27
+ mask = masks[0]
28
+ out = Image.fromarray((mask * 255).astype("uint8"))
29
+ return out
30
+
31
+ with gr.Blocks() as demo:
32
+ with gr.Row():
33
+ with gr.Column():
34
+ inp = gr.Image(type="pil", label="画像アップロード")
35
+ pts = gr.Point(label="クリックして分割位置指定")
36
+ btn = gr.Button("分割実行")
37
+ with gr.Column():
38
+ out = gr.Image(type="pil", label="分割結果")
39
+
40
+ btn.click(segment, inputs=[inp, pts], outputs=out)
41
 
42
  if __name__ == "__main__":
43
  demo.launch()