# File size: 1,444 Bytes
# 4a2ff96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from flask import Flask, request, jsonify
import numpy as np
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from cv2 import imencode
from base64 import b64encode
import requests
import time

# Run inference on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
print(f"Using device: {device}")

# Build the "vit_l" SAM variant from its local checkpoint file, move it to the
# selected device, and wrap it in the automatic (prompt-free) mask generator.
# This happens once at import time so every request reuses the same model.
sam = sam_model_registry["vit_l"](
  checkpoint="sam_vit_l_0b3195.pth",
)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
print("Loaded model")

# WSGI application object; the route handlers below register against it.
app = Flask(__name__)

# Upper bound, in seconds, on connecting to / reading from the remote image
# host, so a stalled download cannot block a worker indefinitely.
_IMAGE_FETCH_TIMEOUT = 30


@app.route('/', methods=['POST'])
def index():
  """Segment the image at the posted URL and return every mask as a PNG.

  Expects a JSON body of the form ``{"url": "<image url>"}``.

  Returns:
    A JSON response ``{"data": [...], "time": <seconds>}``. Each entry of
    ``data`` holds an empty ``label``, a base64-encoded PNG ``mask`` (255
    inside the segment, 0 elsewhere) and a ``score`` taken from the mask's
    ``predicted_iou``. A 400 response with an ``error`` field is returned
    when the body carries no usable ``url``; a 502 response is returned
    when the image cannot be downloaded or decoded.

  Raises:
    RuntimeError: If OpenCV fails to PNG-encode a mask.
  """
  app.logger.info('Got request !')
  start = time.perf_counter()

  # silent=True yields None for a missing/invalid JSON body instead of
  # raising, so every malformed request gets the same explicit 400.
  payload = request.get_json(silent=True)
  url = payload.get('url') if isinstance(payload, dict) else None
  if not isinstance(url, str) or not url:
    return jsonify({"error": "Request body must be JSON with a non-empty 'url' string."}), 400
  app.logger.info('Got request for url %s', url)

  # NOTE(security): this fetches a caller-supplied URL from inside the
  # server's network (SSRF risk). Restrict allowed hosts if this endpoint
  # is reachable by untrusted clients.
  try:
    with requests.get(url, stream=True, timeout=_IMAGE_FETCH_TIMEOUT) as response:
      response.raise_for_status()
      # Without this the raw stream is handed over still gzip/deflate
      # compressed when the server applies a Content-Encoding.
      response.raw.decode_content = True
      # Decode fully inside the `with` so the connection can be released.
      image = np.array(Image.open(response.raw).convert("RGB"))
  except (requests.RequestException, OSError) as err:
    app.logger.warning('Could not load image from %s: %s', url, err)
    return jsonify({"error": "Could not fetch or decode the image at 'url'."}), 502

  masks = mask_generator.generate(image)

  data = []
  for mask in masks:
    # The canvas keeps the image's full (height, width, 3) shape, so each
    # mask is emitted as a 3-channel PNG exactly as before.
    mask_image = np.zeros(image.shape, np.uint8)
    mask_image[np.asarray(mask["segmentation"], dtype=bool)] = 255
    encoded_ok, buffer = imencode('.png', mask_image)
    if not encoded_ok:
      raise RuntimeError("Failed to PNG-encode a segmentation mask")
    data.append({
      "label": "",
      "mask": b64encode(buffer).decode("ascii"),
      # float() keeps the value JSON-serializable even for a NumPy scalar.
      "score": float(mask["predicted_iou"]),
    })

  elapsed = time.perf_counter() - start

  return jsonify({ "data": data, "time": elapsed })

@app.route('/health', methods=['GET'])
def health():
  """Liveness probe: always answers with a JSON ``{"success": true}``."""
  status = {"success": True}
  return jsonify(status)
    
if __name__ == "__main__":
  # Started directly as a script: serve with Flask's built-in server,
  # listening on every network interface at port 8000.
  app.run(
    host="0.0.0.0",
    port=8000,
  )