Update app.py
Browse files
app.py
CHANGED
|
@@ -45,10 +45,6 @@ def get_masks(prompts, img, threshold):
|
|
| 45 |
|
| 46 |
return masks
|
| 47 |
|
| 48 |
-
@app.route('/')
|
| 49 |
-
def hello_world():
|
| 50 |
-
return 'Hello, World!'
|
| 51 |
-
|
| 52 |
# Function to extract image using positive and negative prompts
|
| 53 |
def extract_image(pos_prompts, neg_prompts, img, threshold):
|
| 54 |
positive_masks = get_masks(pos_prompts, img, 0.5)
|
|
@@ -62,7 +58,16 @@ def extract_image(pos_prompts, neg_prompts, img, threshold):
|
|
| 62 |
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
| 63 |
output_image.paste(img, mask=final_mask)
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
@app.route('/api', methods=['POST'])
|
| 68 |
def process_request():
|
|
@@ -79,16 +84,10 @@ def process_request():
|
|
| 79 |
threshold = float(data.get('threshold', 0.4))
|
| 80 |
|
| 81 |
# Perform image segmentation
|
| 82 |
-
|
| 83 |
|
| 84 |
-
|
| 85 |
-
buffered = io.BytesIO()
|
| 86 |
-
output_image.save(buffered, format="PNG")
|
| 87 |
-
result_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 88 |
-
|
| 89 |
-
return jsonify({'result_image_base64': result_image_base64})
|
| 90 |
|
| 91 |
if __name__ == '__main__':
|
| 92 |
print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
|
| 93 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
| 94 |
-
|
|
|
|
| 45 |
|
| 46 |
return masks
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# Function to extract image using positive and negative prompts
|
| 49 |
def extract_image(pos_prompts, neg_prompts, img, threshold):
|
| 50 |
positive_masks = get_masks(pos_prompts, img, 0.5)
|
|
|
|
| 58 |
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
| 59 |
output_image.paste(img, mask=final_mask)
|
| 60 |
|
| 61 |
+
# Convert final mask to base64
|
| 62 |
+
buffered = io.BytesIO()
|
| 63 |
+
final_mask.save(buffered, format="PNG")
|
| 64 |
+
final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 65 |
+
|
| 66 |
+
return final_mask_base64
|
| 67 |
+
|
| 68 |
+
@app.route('/')
|
| 69 |
+
def hello_world():
|
| 70 |
+
return 'Hello, World!'
|
| 71 |
|
| 72 |
@app.route('/api', methods=['POST'])
|
| 73 |
def process_request():
|
|
|
|
| 84 |
threshold = float(data.get('threshold', 0.4))
|
| 85 |
|
| 86 |
# Perform image segmentation
|
| 87 |
+
final_mask_base64 = extract_image(pos_prompts, neg_prompts, img, threshold)
|
| 88 |
|
| 89 |
+
return jsonify({'final_mask_base64': final_mask_base64})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
if __name__ == '__main__':
|
| 92 |
print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
|
| 93 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|
|