sajabdoli committed on
Commit
cb317f0
·
verified ·
1 Parent(s): fd54065

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -19
app.py CHANGED
@@ -1,14 +1,26 @@
1
- from fastapi import FastAPI, File, UploadFile
 
2
  from segment_anything import sam_model_registry, SamPredictor
3
  from PIL import Image
4
  import numpy as np
5
  import torch
6
  import io
 
 
7
 
8
  app = FastAPI()
9
 
 
 
 
 
 
 
 
 
 
10
  # Load SAM Model
11
- sam_checkpoint = "sam_vit_b.pth" # Add the weights file manually in the Space
12
  model_type = "vit_b"
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,22 +33,101 @@ def read_root():
21
 
22
  @app.post("/segment")
23
  async def segment_image(file: UploadFile = File(...)):
24
- image_bytes = await file.read()
25
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
26
- image_np = np.array(image)
27
-
28
- predictor.set_image(image_np)
29
-
30
- input_point = np.array([[100, 100]])
31
- input_label = np.array([1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- masks, scores, _ = predictor.predict(
34
- point_coords=input_point,
35
- point_labels=input_label,
36
- multimask_output=False
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- return {
40
- "score": float(scores[0]),
41
- "mask": masks[0].tolist()
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from segment_anything import sam_model_registry, SamPredictor
4
  from PIL import Image
5
  import numpy as np
6
  import torch
7
  import io
8
+ import base64
9
+ import json
10
 
11
  app = FastAPI()
12
 
13
+ # Add CORS middleware for CVAT
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
  # Load SAM Model
23
+ sam_checkpoint = "sam_vit_b.pth"
24
  model_type = "vit_b"
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
33
 
34
  @app.post("/segment")
35
  async def segment_image(file: UploadFile = File(...)):
36
+ try:
37
+ image_bytes = await file.read()
38
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
39
+ image_np = np.array(image)
40
+
41
+ # Get image dimensions
42
+ height, width = image_np.shape[:2]
43
+
44
+ # Use center point instead of fixed point
45
+ center_point = np.array([[width // 2, height // 2]])
46
+ input_label = np.array([1])
47
+
48
+ predictor.set_image(image_np)
49
+ masks, scores, _ = predictor.predict(
50
+ point_coords=center_point,
51
+ point_labels=input_label,
52
+ multimask_output=True # Return multiple masks
53
+ )
54
+
55
+ # Return the best mask
56
+ best_mask_idx = np.argmax(scores)
57
+ mask = masks[best_mask_idx].astype(bool)
58
+
59
+ return {
60
+ "score": float(scores[best_mask_idx]),
61
+ "mask": mask.tolist()
62
+ }
63
+ except Exception as e:
64
+ raise HTTPException(status_code=500, detail=str(e))
65
 
66
+ # CVAT-specific endpoint
67
+ @app.post("/predict")
68
+ async def predict_for_cvat(body: str = Form(...)):
69
+ try:
70
+ data = json.loads(body)
71
+ image_data = data.get('image', '')
72
+
73
+ # Decode base64 image
74
+ image_bytes = base64.b64decode(image_data)
75
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
76
+ image_np = np.array(image)
77
+
78
+ # Get points from CVAT request
79
+ points = data.get('points', [])
80
+ if not points:
81
+ # If no points, use center of image
82
+ height, width = image_np.shape[:2]
83
+ points = [[width // 2, height // 2]]
84
+
85
+ input_points = np.array(points)
86
+ input_labels = np.ones(len(points))
87
+
88
+ predictor.set_image(image_np)
89
+ masks, scores, _ = predictor.predict(
90
+ point_coords=input_points,
91
+ point_labels=input_labels,
92
+ multimask_output=True
93
+ )
94
+
95
+ # Get best mask
96
+ best_mask_idx = np.argmax(scores)
97
+ mask = masks[best_mask_idx].astype(bool)
98
+
99
+ # Convert mask to CVAT format
100
+ height, width = mask.shape
101
+ rle = mask_to_rle(mask)
102
+
103
+ return {
104
+ "annotations": [{
105
+ "ObjectID": 1,
106
+ "ObjectScore": float(scores[best_mask_idx]),
107
+ "RLE": rle,
108
+ "PredictionType": "mask",
109
+ "width": width,
110
+ "height": height
111
+ }]
112
+ }
113
+ except Exception as e:
114
+ raise HTTPException(status_code=500, detail=str(e))
115
 
116
+ # Helper function to convert mask to RLE (Run-Length Encoding)
117
+ def mask_to_rle(mask):
118
+ """Convert mask to RLE format expected by CVAT"""
119
+ flattened_mask = mask.flatten()
120
+ rle = []
121
+ current_pixel = 0
122
+ count = 0
123
+
124
+ for pixel in flattened_mask:
125
+ if pixel == current_pixel:
126
+ count += 1
127
+ else:
128
+ rle.append(count)
129
+ current_pixel = pixel
130
+ count = 1
131
+
132
+ rle.append(count)
133
+ return rle