Update app.py
Browse files
app.py
CHANGED
|
@@ -61,6 +61,33 @@ def extract_image(pos_prompts, neg_prompts, img, threshold):
|
|
| 61 |
output_image.paste(img, mask=final_mask)
|
| 62 |
return output_image, final_mask
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# Gradio UI
|
| 65 |
with gr.Blocks() as demo:
|
| 66 |
gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
|
|
|
|
| 61 |
output_image.paste(img, mask=final_mask)
|
| 62 |
return output_image, final_mask
|
| 63 |
|
| 64 |
+
|
| 65 |
+
@app.route('/api', methods=['POST'])
|
| 66 |
+
def api():
|
| 67 |
+
data = request.form
|
| 68 |
+
img_url = data['input_image']
|
| 69 |
+
positive_prompts = data['positive_prompts']
|
| 70 |
+
negative_prompts = data['negative_prompts']
|
| 71 |
+
threshold = float(data['input_slider_T'])
|
| 72 |
+
|
| 73 |
+
# Download image from URL
|
| 74 |
+
response = requests.get(img_url)
|
| 75 |
+
img = Image.open(BytesIO(response.content))
|
| 76 |
+
|
| 77 |
+
# Process image
|
| 78 |
+
masks = get_masks(positive_prompts, negative_prompts, img, threshold)
|
| 79 |
+
final_mask = np.any(np.stack(masks), axis=0)
|
| 80 |
+
|
| 81 |
+
# Convert mask to image
|
| 82 |
+
final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
|
| 83 |
+
|
| 84 |
+
# Convert the final image to bytes
|
| 85 |
+
img_bytes = BytesIO()
|
| 86 |
+
final_mask.save(img_bytes, format='PNG')
|
| 87 |
+
img_bytes.seek(0)
|
| 88 |
+
|
| 89 |
+
return send_file(img_bytes, mimetype='image/png')
|
| 90 |
+
|
| 91 |
# Gradio UI
|
| 92 |
with gr.Blocks() as demo:
|
| 93 |
gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
|