ivanm151 committed on
Commit
f8ed5ba
·
1 Parent(s): 777d510

mobilesam v1.0

Browse files
Files changed (5) hide show
  1. app.py +45 -28
  2. models.py +6 -11
  3. requirements.txt +2 -1
  4. utils.py +13 -102
  5. weights/{seg0.pth → mobile_sam.pt} +2 -2
app.py CHANGED
@@ -1,30 +1,26 @@
1
  from fastapi import FastAPI, UploadFile, File, Query
2
  import torch
 
 
 
 
3
  from models import load_model1, load_model2, load_model3
4
  from utils import (
5
- preprocess_image,
6
- predict_mask_tta,
7
- postprocess_mask,
8
- mask_to_base64,
9
- apply_white_background_and_crop,
10
  preprocess_for_classifier,
11
  FRUIT_CLASSES,
12
  FRESHNESS_CLASSES
13
  )
14
- import numpy as np
15
- from PIL import Image
16
- import io
17
 
18
  app = FastAPI()
19
 
20
- # Π“Π»ΠΎΠ±Π°Π»ΡŒΠ½Π°Ρ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ
21
- model1 = load_model1() # segmentation (448)
22
- model2 = load_model2() # fruit type
23
- model3 = load_model3() # freshness
24
 
25
  DEVICE = torch.device('cpu')
26
 
27
- # ΠšΠ»Π°ΡΡΡ‹, для ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Ρ… Π΄Π΅Π»Π°Π΅ΠΌ ΡΠ²Π΅ΠΆΠ΅ΡΡ‚ΡŒ
28
  FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange', 'lemon'}
29
 
30
  @app.get("/")
@@ -34,19 +30,33 @@ def greet_json():
34
  @app.post("/predict_full")
35
  async def predict_full(
36
  file: UploadFile = File(...),
37
- return_mask: bool = Query(default=False, description="Π’Π΅Ρ€Π½ΡƒΡ‚ΡŒ base64 маски?")
 
 
38
  ):
39
  content = await file.read()
40
  image = Image.open(io.BytesIO(content)).convert('RGB')
41
  orig_np = np.array(image)
42
 
43
- # БСгмСнтация
44
- input_tensor = preprocess_image(orig_np).unsqueeze(0).to(DEVICE)
45
- with torch.no_grad():
46
- prob = predict_mask_tta(model1, input_tensor)
47
- mask = postprocess_mask(prob.squeeze().cpu().numpy())
48
 
49
- fruit_area_ratio = np.mean(mask > 0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  if fruit_area_ratio < 0.01:
51
  return {
52
  "status": "no_fruit_detected",
@@ -55,12 +65,11 @@ async def predict_full(
55
  "fruit_confidence": None,
56
  "freshness": None,
57
  "freshness_confidence": None,
58
- "mask_base64": mask_to_base64(mask) if return_mask else None
59
  }
60
 
61
- # Для сорта (100Γ—100)
62
- cropped_100 = apply_white_background_and_crop(orig_np, mask, out_size=100)
63
-
64
  input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
65
  with torch.no_grad():
66
  logits2 = model2(input_tensor2)
@@ -77,13 +86,12 @@ async def predict_full(
77
  "fruit_confidence": round(fruit_conf, 4),
78
  "freshness": None,
79
  "freshness_confidence": None,
80
- "mask_base64": mask_to_base64(mask) if return_mask else None
81
  }
82
 
83
- # Если Ρ„Ρ€ΡƒΠΊΡ‚ ΠΏΠΎΠ΄Ρ…ΠΎΠ΄ΠΈΡ‚ β€” ΡΠ²Π΅ΠΆΠ΅ΡΡ‚ΡŒ (224Γ—224)
84
  if fruit_name in FRESHNESS_ELIGIBLE:
85
- cropped_224 = apply_white_background_and_crop(orig_np, mask, out_size=224)
86
-
87
  input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
88
  with torch.no_grad():
89
  logits3 = model3(input_tensor3)
@@ -96,4 +104,13 @@ async def predict_full(
96
  result["freshness"] = fresh_name
97
  result["freshness_confidence"] = round(fresh_conf, 4)
98
 
 
 
 
 
 
 
 
 
 
99
  return result
 
1
  from fastapi import FastAPI, UploadFile, File, Query
2
  import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import base64
6
+ import io
7
  from models import load_model1, load_model2, load_model3
8
  from utils import (
9
+ crop_fruit_with_white_bg,
 
 
 
 
10
  preprocess_for_classifier,
11
  FRUIT_CLASSES,
12
  FRESHNESS_CLASSES
13
  )
 
 
 
14
 
15
  app = FastAPI()
16
 
17
+ # Π—Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»Π΅ΠΉ
18
+ sam_predictor = load_model1() # MobileSAM
19
+ model2 = load_model2()
20
+ model3 = load_model3()
21
 
22
  DEVICE = torch.device('cpu')
23
 
 
24
  FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange', 'lemon'}
25
 
26
  @app.get("/")
 
30
  @app.post("/predict_full")
31
  async def predict_full(
32
  file: UploadFile = File(...),
33
+ point_x: int = Query(..., description="X-ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Π° Ρ‚ΠΎΡ‡ΠΊΠΈ Π½Π° Ρ„Ρ€ΡƒΠΊΡ‚Π΅ (Π² пиксСлях ΠΎΡ€ΠΈΠ³ΠΈΠ½Π°Π»ΡŒΠ½ΠΎΠ³ΠΎ изобраТСния)"),
34
+ point_y: int = Query(..., description="Y-ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Π° Ρ‚ΠΎΡ‡ΠΊΠΈ Π½Π° Ρ„Ρ€ΡƒΠΊΡ‚Π΅"),
35
+ return_cropped: bool = Query(default=True, description="Π’Π΅Ρ€Π½ΡƒΡ‚ΡŒ ΠΎΠ±Ρ€Π΅Π·Π°Π½Π½ΠΎΠ΅ ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠ΅ Π² base64?")
36
  ):
37
  content = await file.read()
38
  image = Image.open(io.BytesIO(content)).convert('RGB')
39
  orig_np = np.array(image)
40
 
41
+ # Установка изобраТСния Π² SAM
42
+ sam_predictor.set_image(orig_np)
 
 
 
43
 
44
+ # ΠŸΡ€ΠΎΠΌΠΏΡ‚: Ρ‚ΠΎΡ‡ΠΊΠ° Π½Π° Ρ„Ρ€ΡƒΠΊΡ‚Π΅
45
+ input_point = np.array([[point_x, point_y]])
46
+ input_label = np.array([1]) # 1 = foreground
47
+
48
+ masks, scores, _ = sam_predictor.predict(
49
+ point_coords=input_point,
50
+ point_labels=input_label,
51
+ multimask_output=False # Одна маска
52
+ )
53
+
54
+ # Π‘Π΅Ρ€Ρ‘ΠΌ Π»ΡƒΡ‡ΡˆΡƒΡŽ маску
55
+ best_mask_idx = np.argmax(scores)
56
+ mask = masks[best_mask_idx] # bool
57
+
58
+ # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ°: Π΅ΡΡ‚ΡŒ Π»ΠΈ Ρ„Ρ€ΡƒΠΊΡ‚?
59
+ fruit_area_ratio = np.mean(mask)
60
  if fruit_area_ratio < 0.01:
61
  return {
62
  "status": "no_fruit_detected",
 
65
  "fruit_confidence": None,
66
  "freshness": None,
67
  "freshness_confidence": None,
68
+ "cropped_base64": None
69
  }
70
 
71
+ # ΠžΠ±Ρ€Π΅Π·ΠΊΠ° ΠΏΠΎΠ΄ 100Γ—100 для сорта
72
+ cropped_100 = crop_fruit_with_white_bg(orig_np, mask, out_size=100)
 
73
  input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
74
  with torch.no_grad():
75
  logits2 = model2(input_tensor2)
 
86
  "fruit_confidence": round(fruit_conf, 4),
87
  "freshness": None,
88
  "freshness_confidence": None,
89
+ "cropped_base64": None
90
  }
91
 
92
+ # Π‘Π²Π΅ΠΆΠ΅ΡΡ‚ΡŒ, Ссли ΠΏΠΎΠ΄Ρ…ΠΎΠ΄ΠΈΡ‚
93
  if fruit_name in FRESHNESS_ELIGIBLE:
94
+ cropped_224 = crop_fruit_with_white_bg(orig_np, mask, out_size=224)
 
95
  input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
96
  with torch.no_grad():
97
  logits3 = model3(input_tensor3)
 
104
  result["freshness"] = fresh_name
105
  result["freshness_confidence"] = round(fresh_conf, 4)
106
 
107
+ # Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅ΠΌ ΠΎΠ±Ρ€Π΅Π·Π°Π½Π½ΠΎΠ΅ ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠ΅ (ΠΏΠΎ ΡƒΠΌΠΎΠ»Ρ‡Π°Π½ΠΈΡŽ 224Γ—224)
108
+ if return_cropped:
109
+ cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=224)
110
+ pil_img = Image.fromarray(cropped_final)
111
+ buffered = io.BytesIO()
112
+ pil_img.save(buffered, format="PNG")
113
+ result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
114
+ result["cropped_size"] = "224x224"
115
+
116
  return result
models.py CHANGED
@@ -2,25 +2,20 @@ import torch
2
  import torchvision.models as models
3
  import torch.nn as nn
4
  import segmentation_models_pytorch as smp
 
5
 
6
  DEVICE = torch.device('cpu')
7
 
8
- model1 = None # сСгмСнтация
9
  model2 = None # сорт Ρ„Ρ€ΡƒΠΊΡ‚Π°
10
  model3 = None # ΡΠ²Π΅ΠΆΠ΅ΡΡ‚ΡŒ
11
 
12
- def load_model1(weights_path='weights/seg0.pth'):
13
  global model1
14
  if model1 is None:
15
- model1 = smp.Unet(
16
- encoder_name="mobilenet_v2",
17
- encoder_weights=None,
18
- in_channels=3,
19
- classes=1,
20
- activation=None
21
- ).to(DEVICE)
22
- state_dict = torch.load(weights_path, map_location=DEVICE)
23
- model1.load_state_dict(state_dict)
24
  model1.eval()
25
  return model1
26
 
 
2
  import torchvision.models as models
3
  import torch.nn as nn
4
  import segmentation_models_pytorch as smp
5
+ from mobile_sam import sam_model_registry, SamPredictor
6
 
7
  DEVICE = torch.device('cpu')
8
 
9
+ model1 = None # Ρ‚Π΅ΠΏΠ΅Ρ€ΡŒ это MobileSAM
10
  model2 = None # сорт Ρ„Ρ€ΡƒΠΊΡ‚Π°
11
  model3 = None # ΡΠ²Π΅ΠΆΠ΅ΡΡ‚ΡŒ
12
 
13
+ def load_model1(weights_path='weights/mobile_sam.pt'):
14
  global model1
15
  if model1 is None:
16
+ model_type = "vit_t"
17
+ model1 = sam_model_registry[model_type](checkpoint=weights_path)
18
+ model1.to(DEVICE)
 
 
 
 
 
 
19
  model1.eval()
20
  return model1
21
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ albumentations
7
  pillow
8
  numpy
9
  opencv-python-headless
10
- python-multipart
 
 
7
  pillow
8
  numpy
9
  opencv-python-headless
10
+ python-multipart
11
+ git+https://github.com/ChaoningZhang/MobileSAM.git
utils.py CHANGED
@@ -1,100 +1,16 @@
1
  import numpy as np
2
- import albumentations as A
3
- from albumentations.pytorch import ToTensorV2
4
- import torch
5
  import cv2
 
6
  from PIL import Image
7
  import io
8
  import base64
9
  from torchvision import transforms
 
10
 
11
- # ────────────────────────────────────────────────
12
- # Новый Ρ€Π°Π·ΠΌΠ΅Ρ€ Π²Ρ…ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ β€” 448Γ—448
13
- # ────────────────────────────────────────────────
14
- IMG_SIZE = 448
15
-
16
- preprocess_transform = A.Compose([
17
- A.Resize(IMG_SIZE, IMG_SIZE),
18
- A.Normalize(), # mean/std ImageNet β€” Ρ‚ΠΎ ΠΆΠ΅, Ρ‡Ρ‚ΠΎ ΠΈ Π² ΠΎΠ±ΡƒΡ‡Π΅Π½ΠΈΠΈ
19
- ToTensorV2()
20
- ])
21
-
22
- def preprocess_image(image_np: np.ndarray) -> torch.Tensor:
23
- augmented = preprocess_transform(image=image_np)
24
- return augmented['image']
25
-
26
- # ────────────────────────────────────────────────
27
- # TTA-прСдсказаниС (ΠΊΠ°ΠΊ Π² Ρ‚Π²ΠΎΡ‘ΠΌ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π΅)
28
- # ────────────────────────────────────────────────
29
- @torch.no_grad()
30
- def predict_mask_tta(model, image_tensor):
31
- preds = []
32
- # ΠžΡ€ΠΈΠ³ΠΈΠ½Π°Π»
33
- preds.append(torch.sigmoid(model(image_tensor)))
34
- # Flip horizontal
35
- preds.append(
36
- torch.flip(
37
- torch.sigmoid(model(torch.flip(image_tensor, dims=[3]))),
38
- dims=[3]
39
- )
40
- )
41
- # Flip vertical
42
- preds.append(
43
- torch.flip(
44
- torch.sigmoid(model(torch.flip(image_tensor, dims=[2]))),
45
- dims=[2]
46
- )
47
- )
48
- return torch.mean(torch.stack(preds), dim=0)
49
-
50
- # ────────────────────────────────────────────────
51
- # Post-processing маски (ΠΊΠ°ΠΊ Π² Ρ‚Π²ΠΎΡ‘ΠΌ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π΅ + морфология)
52
- # ────────────────────────────────────────────────
53
- def postprocess_mask(prob: np.ndarray, threshold: float = 0.65, min_area_ratio: float = 0.01) -> np.ndarray:
54
- binary = (prob > threshold).astype(np.uint8)
55
-
56
- # Connected components β€” оставляСм Ρ‚ΠΎΠ»ΡŒΠΊΠΎ Π³Π»Π°Π²Π½Ρ‹ΠΉ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚
57
- num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
58
-
59
- if num_labels <= 1:
60
- return binary.astype(np.float32)
61
-
62
- largest_label = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
63
- area = stats[largest_label, cv2.CC_STAT_AREA]
64
-
65
- if area < binary.shape[0] * binary.shape[1] * min_area_ratio:
66
- return np.zeros_like(binary, dtype=np.float32)
67
-
68
- clean_mask = (labels == largest_label).astype(np.float32)
69
-
70
- # ΠœΠΎΡ€Ρ„ΠΎΠ»ΠΎΠ³ΠΈΡ (Π·Π°ΠΏΠΎΠ»Π½ΠΈΡ‚ΡŒ Π΄Ρ‹Ρ€ΠΊΠΈ, ΡƒΠ±Ρ€Π°Ρ‚ΡŒ ΡˆΡƒΠΌ)
71
- kernel = np.ones((3, 3), np.uint8)
72
- clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_CLOSE, kernel)
73
- clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel)
74
-
75
- return clean_mask
76
-
77
- # ────────────────────────────────────────────────
78
- # Base64 маски (для Π²ΠΎΠ·Π²Ρ€Π°Ρ‚Π° ΠΊΠ»ΠΈΠ΅Π½Ρ‚Ρƒ)
79
- # ────────────────────────────────────────────────
80
- def mask_to_base64(mask: np.ndarray) -> str:
81
- pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
82
- buffered = io.BytesIO()
83
- pil_mask.save(buffered, format="PNG")
84
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
85
-
86
- # ────────────────────────────────────────────────
87
- # ΠšΠΎΠ½ΡΡ‚Π°Π½Ρ‚Ρ‹ классов
88
- # ────────────────────────────────────────────────
89
  FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
90
- FRESHNESS_CLASSES = [
91
- 'freshapples', 'freshbanana', 'freshoranges',
92
- 'rottenapples', 'rottenbanana', 'rottenoranges'
93
- ]
94
 
95
- # ────────────────────────────────────────────────
96
- # Preprocess для классификаторов (100 ΠΈ 224)
97
- # ────────────────────────────────────────────────
98
  def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
99
  transform = transforms.Compose([
100
  transforms.ToPILImage(),
@@ -103,9 +19,7 @@ def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
103
  ])
104
  return transform(img)
105
 
106
- # ────────────────────────────────────────────────
107
- # Π£Π½ΠΈΠ²Π΅Ρ€ΡΠ°Π»ΡŒΠ½Ρ‹ΠΉ letterbox (для любого target_size)
108
- # ────────────────────────────────────────────────
109
  def letterbox_any_size(
110
  img: np.ndarray,
111
  target_size: int = 224,
@@ -124,22 +38,19 @@ def letterbox_any_size(
124
  left = pad_w // 2
125
  right = pad_w - left
126
 
127
- padded = cv2.copyMakeBorder(
128
- resized, top, bottom, left, right,
129
- cv2.BORDER_CONSTANT, value=bg_color
130
- )
131
  return padded
132
 
133
- # ────────────────────────────────────────────────
134
- # Apply white background + crop ΠΏΠΎ маскС (448Γ—448)
135
- # ────────────────────────────────────────────────
136
- def apply_white_background_and_crop(
137
- orig_img: np.ndarray, # RGB
138
- mask: np.ndarray, # float [0,1] 448Γ—448
139
  out_size: int = 224,
140
  bg_color: tuple = (255, 255, 255)
141
  ) -> np.ndarray:
142
- mask_bin = (mask > 0.5).astype(np.uint8)
 
143
 
144
  ys, xs = np.where(mask_bin == 1)
145
  if len(xs) == 0:
 
1
  import numpy as np
 
 
 
2
  import cv2
3
+ import torch
4
  from PIL import Image
5
  import io
6
  import base64
7
  from torchvision import transforms
8
+ from mobile_sam import SamPredictor
9
 
10
+ # ΠšΠΎΠ½ΡΡ‚Π°Π½Ρ‚Ρ‹
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
12
+ FRESHNESS_CLASSES = ['freshapples', 'freshbanana', 'freshoranges', 'rottenapples', 'rottenbanana', 'rottenoranges']
 
 
 
13
 
 
 
 
14
  def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
15
  transform = transforms.Compose([
16
  transforms.ToPILImage(),
 
19
  ])
20
  return transform(img)
21
 
22
+ # Π£Π½ΠΈΠ²Π΅Ρ€ΡΠ°Π»ΡŒΠ½Ρ‹ΠΉ letterbox (Π±Π΅Π· искаТСния)
 
 
23
  def letterbox_any_size(
24
  img: np.ndarray,
25
  target_size: int = 224,
 
38
  left = pad_w // 2
39
  right = pad_w - left
40
 
41
+ padded = cv2.copyMakeBorder(resized, top, bottom, left, right,
42
+ cv2.BORDER_CONSTANT, value=bg_color)
 
 
43
  return padded
44
 
45
+ # ΠžΠ±Ρ€Π΅Π·ΠΊΠ° ΠΏΠΎ маскС SAM + Π±Π΅Π»Ρ‹ΠΉ Ρ„ΠΎΠ½ + letterbox
46
+ def crop_fruit_with_white_bg(
47
+ orig_img: np.ndarray, # RGB
48
+ mask: np.ndarray, # bool ΠΈΠ»ΠΈ uint8 ΠΎΡ‚ SAM
 
 
49
  out_size: int = 224,
50
  bg_color: tuple = (255, 255, 255)
51
  ) -> np.ndarray:
52
+ # Маска β†’ binary
53
+ mask_bin = mask.astype(np.uint8)
54
 
55
  ys, xs = np.where(mask_bin == 1)
56
  if len(xs) == 0:
weights/{seg0.pth → mobile_sam.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4e2a778652280420b80ebc949fac8e2d1a95737d28884e6fa99df2509c7410db
3
- size 26806811
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f
3
+ size 40728226