yurista commited on
Commit
8f7f834
Β·
verified Β·
1 Parent(s): 18a3c4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -101
app.py CHANGED
@@ -5,152 +5,109 @@ from PIL import Image
5
  import base64
6
  import io
7
  import os
8
-
9
  from segment_anything import sam_model_registry, SamPredictor
10
  from diffusers import StableDiffusionXLInpaintPipeline
11
 
12
-
13
- # ------------------- Load model -------------------
14
  MODEL_PATH = "sam_vit_b_01ec64.pth"
15
 
16
  if not os.path.exists(MODEL_PATH):
17
  os.system(f"wget https://dl.fbaipublicfiles.com/segment_anything/{MODEL_PATH}")
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
20
  sam = sam_model_registry["vit_b"](checkpoint=MODEL_PATH)
21
  sam.to(device=device)
22
  predictor = SamPredictor(sam)
 
23
 
24
- # ------------------- Load SDXL Inpainting -------------------
25
- print("πŸͺ„ Loading Stable Diffusion XL Inpainting Model...")
26
- from diffusers import StableDiffusionXLInpaintPipeline
27
-
28
- sdxl_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
29
  pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
30
- sdxl_model_id,
 
31
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
32
- variant="fp16",
33
  )
34
  pipe = pipe.to(device)
35
- print("βœ… SDXL Inpainting loaded successfully!")
36
 
37
- # ------------------- Helper -------------------
38
- def decode_base64_image(image_base64: str):
39
- """Decode base64 string menjadi PIL Image"""
40
- image_data = base64.b64decode(image_base64)
41
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
42
- return np.array(image)
43
 
44
-
45
- def encode_mask_to_base64(mask: np.ndarray):
46
- """Encode mask numpy array menjadi base64 PNG"""
47
- mask_image = Image.fromarray((mask * 255).astype(np.uint8))
48
  buffered = io.BytesIO()
49
- mask_image.save(buffered, format="PNG")
50
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
51
 
 
 
 
 
 
52
 
53
- # ------------------- Inference function -------------------
54
- def predict(image_base64, box=None, points=None, labels=None):
55
- """
56
- image_base64: string base64 dari gambar RGB
57
- box: [x1, y1, x2, y2] (optional)
58
- points: list of [x, y] (optional)
59
- labels: list of 1/0 (optional)
60
- """
61
-
62
  try:
63
- image_np = decode_base64_image(image_base64)
 
64
  predictor.set_image(image_np)
65
 
66
- box_np = np.array(box) if box else None
67
- points_np = np.array(points) if points else None
68
- labels_np = np.array(labels) if labels else None
 
69
 
70
- masks, scores, logits = predictor.predict(
71
- point_coords=points_np,
72
- point_labels=labels_np,
73
- box=box_np,
74
  multimask_output=True
75
  )
76
 
77
  best_idx = np.argmax(scores)
78
  mask = masks[best_idx]
 
79
 
80
- mask_base64 = encode_mask_to_base64(mask)
81
- return {"mask_base64": mask_base64, "score": float(scores[best_idx])}
82
-
83
- except Exception as e:
84
- return {"error": str(e)}
85
-
86
- def inpaint_background(image_base64, mask_base64, prompt, negative_prompt="", guidance_scale=7.5, steps=30, seed=42):
87
- """
88
- Mengganti background berdasarkan prompt menggunakan model SDXL Inpainting.
89
- image_base64: base64 dari gambar RGBA (foreground+alpha)
90
- mask_base64: base64 dari mask (putih=area yang diganti)
91
- """
92
- try:
93
- # Decode image
94
- image = decode_base64_image(image_base64)
95
- mask_data = base64.b64decode(mask_base64)
96
- mask = Image.open(io.BytesIO(mask_data)).convert("L")
97
-
98
  generator = torch.manual_seed(int(seed))
99
  result = pipe(
100
  prompt=prompt,
101
  negative_prompt=negative_prompt,
102
- image=Image.fromarray(image),
103
- mask_image=mask,
104
  guidance_scale=float(guidance_scale),
105
  num_inference_steps=int(steps),
106
  generator=generator
107
  ).images[0]
108
 
109
- buffered = io.BytesIO()
110
- result.save(buffered, format="PNG")
111
- result_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
112
-
113
- return {"result_base64": result_b64, "status": "βœ… Success"}
114
 
115
  except Exception as e:
116
  import traceback
117
- return {"error": str(e), "traceback": traceback.format_exc()}
118
-
119
 
120
  # ------------------- Gradio Interface -------------------
121
- demo = gr.Interface(
122
- fn=predict,
123
- inputs=[
124
- gr.Textbox(label="Image (Base64)", lines=5, placeholder="Base64 encoded image"),
125
- gr.Textbox(label="Box [x1, y1, x2, y2]", placeholder="[100, 100, 300, 300]"),
126
- gr.Textbox(label="Points (optional)", placeholder="[[120,150],[130,160]]"),
127
- gr.Textbox(label="Labels (optional)", placeholder="[1,0]"),
128
- ],
129
- outputs="json",
130
- title="SAM Segmentation API",
131
- description="API untuk segmentasi gambar menggunakan Segment Anything Model (SAM)."
132
- )
133
-
134
- demo2 = gr.Interface(
135
- fn=inpaint_background,
136
- inputs=[
137
- gr.Textbox(label="Image (Base64)", lines=3),
138
- gr.Textbox(label="Mask (Base64)", lines=3),
139
- gr.Textbox(label="Prompt", placeholder="beautiful beach background"),
140
- gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry"),
141
- gr.Slider(1, 20, value=7.5, step=0.5, label="Guidance Scale"),
142
- gr.Slider(10, 50, value=30, step=5, label="Steps"),
143
- gr.Number(value=42, label="Seed"),
144
- ],
145
- outputs="json",
146
- title="SDXL Background Inpainting API",
147
- description="API untuk mengganti background menggunakan Stable Diffusion XL Inpainting."
148
- )
149
-
150
-
151
- if __name__ == "__main__":
152
- app = gr.TabbedInterface(
153
- [demo, demo2],
154
- ["SAM Segmentation", "Background Inpainting"]
155
  )
156
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
5
  import base64
6
  import io
7
  import os
 
8
  from segment_anything import sam_model_registry, SamPredictor
9
  from diffusers import StableDiffusionXLInpaintPipeline
10
 
11
+ # ------------------- Load Models -------------------
 
12
  MODEL_PATH = "sam_vit_b_01ec64.pth"
13
 
14
  if not os.path.exists(MODEL_PATH):
15
  os.system(f"wget https://dl.fbaipublicfiles.com/segment_anything/{MODEL_PATH}")
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ print("🧠 Loading SAM model...")
20
  sam = sam_model_registry["vit_b"](checkpoint=MODEL_PATH)
21
  sam.to(device=device)
22
  predictor = SamPredictor(sam)
23
+ print("βœ… SAM loaded successfully!")
24
 
25
+ print("🎨 Loading SDXL Inpainting model...")
 
 
 
 
26
  pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-xl-base-1.0",
28
+ revision="fp16",
29
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
30
  )
31
  pipe = pipe.to(device)
32
+ print("βœ… SDXL loaded successfully!")
33
 
34
+ # ------------------- Helper Functions -------------------
35
+ def np_to_pil(np_img):
36
+ return Image.fromarray((np_img * 255).astype(np.uint8)) if np_img.dtype == np.float32 else Image.fromarray(np_img)
 
 
 
37
 
38
+ def pil_to_b64(image: Image.Image):
 
 
 
39
  buffered = io.BytesIO()
40
+ image.save(buffered, format="PNG")
41
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
42
 
43
+ def decode_image(image):
44
+ if isinstance(image, str):
45
+ image_data = base64.b64decode(image)
46
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
47
+ return np.array(image)
48
 
49
+ # ------------------- Main Pipeline -------------------
50
+ def segment_and_inpaint(image, prompt, negative_prompt="", guidance_scale=7.5, steps=30, seed=42):
 
 
 
 
 
 
 
51
  try:
52
+ # Step 1: Segmentasi dengan SAM
53
+ image_np = np.array(image.convert("RGB"))
54
  predictor.set_image(image_np)
55
 
56
+ # Gunakan titik tengah gambar sebagai fokus sementara
57
+ h, w, _ = image_np.shape
58
+ points = np.array([[w // 2, h // 2]])
59
+ labels = np.array([1])
60
 
61
+ masks, scores, _ = predictor.predict(
62
+ point_coords=points,
63
+ point_labels=labels,
 
64
  multimask_output=True
65
  )
66
 
67
  best_idx = np.argmax(scores)
68
  mask = masks[best_idx]
69
+ mask_pil = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
70
 
71
+ # Step 2: Inpainting Background
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  generator = torch.manual_seed(int(seed))
73
  result = pipe(
74
  prompt=prompt,
75
  negative_prompt=negative_prompt,
76
+ image=image,
77
+ mask_image=mask_pil,
78
  guidance_scale=float(guidance_scale),
79
  num_inference_steps=int(steps),
80
  generator=generator
81
  ).images[0]
82
 
83
+ return result
 
 
 
 
84
 
85
  except Exception as e:
86
  import traceback
87
+ print(traceback.format_exc())
88
+ return f"❌ Error: {str(e)}"
89
 
90
  # ------------------- Gradio Interface -------------------
91
+ with gr.Blocks(title="SAM + SDXL Background Changer") as app:
92
+ gr.Markdown("## 🎨 Background Changer using SAM + Stable Diffusion XL Inpainting")
93
+
94
+ with gr.Row():
95
+ with gr.Column():
96
+ input_image = gr.Image(label="Upload Image", type="pil")
97
+ prompt = gr.Textbox(label="Prompt (background description)", placeholder="a beach at sunset")
98
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry")
99
+ guidance_scale = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance Scale")
100
+ steps = gr.Slider(10, 50, value=30, step=5, label="Inference Steps")
101
+ seed = gr.Number(value=42, label="Random Seed")
102
+ submit_btn = gr.Button("✨ Change Background")
103
+
104
+ with gr.Column():
105
+ output_image = gr.Image(label="Result", type="pil")
106
+
107
+ submit_btn.click(
108
+ fn=segment_and_inpaint,
109
+ inputs=[input_image, prompt, negative_prompt, guidance_scale, steps, seed],
110
+ outputs=[output_image]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
+
113
+ app.launch(server_name="0.0.0.0", server_port=7860)