ivanm151 commited on
Commit
b1e9f50
·
1 Parent(s): f8ed5ba

mobilesam v1.1

Browse files
Files changed (3) hide show
  1. app.py +14 -15
  2. models.py +11 -12
  3. utils.py +3 -7
app.py CHANGED
@@ -2,9 +2,9 @@ from fastapi import FastAPI, UploadFile, File, Query
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
- import base64
6
  import io
7
- from models import load_model1, load_model2, load_model3
 
8
  from utils import (
9
  crop_fruit_with_white_bg,
10
  preprocess_for_classifier,
@@ -15,9 +15,9 @@ from utils import (
15
  app = FastAPI()
16
 
17
  # Загрузка моделей
18
- sam_predictor = load_model1() # MobileSAM
19
- model2 = load_model2()
20
- model3 = load_model3()
21
 
22
  DEVICE = torch.device('cpu')
23
 
@@ -32,23 +32,22 @@ async def predict_full(
32
  file: UploadFile = File(...),
33
  point_x: int = Query(..., description="X-координата точки на фрукте (в пикселях оригинального изображения)"),
34
  point_y: int = Query(..., description="Y-координата точки на фрукте"),
35
- return_cropped: bool = Query(default=True, description="Вернуть обрезанное изображение в base64?")
 
36
  ):
37
  content = await file.read()
38
  image = Image.open(io.BytesIO(content)).convert('RGB')
39
  orig_np = np.array(image)
40
 
41
- # Установка изображения в SAM
42
- sam_predictor.set_image(orig_np)
43
-
44
- # Промпт: точка на фрукте
45
  input_point = np.array([[point_x, point_y]])
46
  input_label = np.array([1]) # 1 = foreground
47
 
48
- masks, scores, _ = sam_predictor.predict(
 
49
  point_coords=input_point,
50
  point_labels=input_label,
51
- multimask_output=False # Одна маска
52
  )
53
 
54
  # Берём лучшую маску
@@ -104,13 +103,13 @@ async def predict_full(
104
  result["freshness"] = fresh_name
105
  result["freshness_confidence"] = round(fresh_conf, 4)
106
 
107
- # Возвращаем обрезанное изображение (по умолчанию 224×224)
108
  if return_cropped:
109
- cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=224)
110
  pil_img = Image.fromarray(cropped_final)
111
  buffered = io.BytesIO()
112
  pil_img.save(buffered, format="PNG")
113
  result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
114
- result["cropped_size"] = "224x224"
115
 
116
  return result
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
 
5
  import io
6
+ import base64
7
+ from models import load_sam, load_model2, load_model3
8
  from utils import (
9
  crop_fruit_with_white_bg,
10
  preprocess_for_classifier,
 
15
  app = FastAPI()
16
 
17
  # Загрузка моделей
18
+ sam_model = load_sam() # MobileSAM
19
+ model2 = load_model2() # сорт
20
+ model3 = load_model3() # свежесть
21
 
22
  DEVICE = torch.device('cpu')
23
 
 
32
  file: UploadFile = File(...),
33
  point_x: int = Query(..., description="X-координата точки на фрукте (в пикселях оригинального изображения)"),
34
  point_y: int = Query(..., description="Y-координата точки на фрукте"),
35
+ return_cropped: bool = Query(default=True, description="Вернуть обрезанное изображение в base64?"),
36
+ cropped_size: int = Query(224, description="Размер обрезанного изображения (100 или 224)")
37
  ):
38
  content = await file.read()
39
  image = Image.open(io.BytesIO(content)).convert('RGB')
40
  orig_np = np.array(image)
41
 
42
+ # MobileSAM: сегментация по точке
 
 
 
43
  input_point = np.array([[point_x, point_y]])
44
  input_label = np.array([1]) # 1 = foreground
45
 
46
+ masks, scores, _ = sam_model.predict(
47
+ image=orig_np,
48
  point_coords=input_point,
49
  point_labels=input_label,
50
+ multimask_output=False
51
  )
52
 
53
  # Берём лучшую маску
 
103
  result["freshness"] = fresh_name
104
  result["freshness_confidence"] = round(fresh_conf, 4)
105
 
106
+ # Возвращаем обрезанное изображение (по умолчанию cropped_size)
107
  if return_cropped:
108
+ cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=cropped_size)
109
  pil_img = Image.fromarray(cropped_final)
110
  buffered = io.BytesIO()
111
  pil_img.save(buffered, format="PNG")
112
  result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
113
+ result["cropped_size"] = f"{cropped_size}x{cropped_size}"
114
 
115
  return result
models.py CHANGED
@@ -1,23 +1,22 @@
1
  import torch
2
  import torchvision.models as models
3
  import torch.nn as nn
4
- import segmentation_models_pytorch as smp
5
- from mobile_sam import sam_model_registry, SamPredictor
6
 
7
  DEVICE = torch.device('cpu')
8
 
9
- model1 = None # теперь это MobileSAM
10
- model2 = None # сорт фрукта
11
- model3 = None # свежесть
12
 
13
- def load_model1(weights_path='weights/mobile_sam.pt'):
14
- global model1
15
- if model1 is None:
16
  model_type = "vit_t"
17
- model1 = sam_model_registry[model_type](checkpoint=weights_path)
18
- model1.to(DEVICE)
19
- model1.eval()
20
- return model1
21
 
22
  def load_model2(weights_path='weights/class.pth'):
23
  global model2
 
1
  import torch
2
  import torchvision.models as models
3
  import torch.nn as nn
4
+ from mobile_sam import sam_model_registry
 
5
 
6
  DEVICE = torch.device('cpu')
7
 
8
+ sam_model = None # MobileSAM
9
+ model2 = None # сорт фрукта
10
+ model3 = None # свежесть
11
 
12
+ def load_sam(weights_path='weights/mobile_sam.pt'):
13
+ global sam_model
14
+ if sam_model is None:
15
  model_type = "vit_t"
16
+ sam_model = sam_model_registry[model_type](checkpoint=weights_path)
17
+ sam_model.to(DEVICE)
18
+ sam_model.eval()
19
+ return sam_model
20
 
21
  def load_model2(weights_path='weights/class.pth'):
22
  global model2
utils.py CHANGED
@@ -1,13 +1,11 @@
1
  import numpy as np
2
  import cv2
3
- import torch
4
  from PIL import Image
5
  import io
6
  import base64
 
7
  from torchvision import transforms
8
- from mobile_sam import SamPredictor
9
 
10
- # Константы
11
  FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
12
  FRESHNESS_CLASSES = ['freshapples', 'freshbanana', 'freshoranges', 'rottenapples', 'rottenbanana', 'rottenoranges']
13
 
@@ -19,7 +17,7 @@ def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
19
  ])
20
  return transform(img)
21
 
22
- # Универсальный letterbox (без искажения)
23
  def letterbox_any_size(
24
  img: np.ndarray,
25
  target_size: int = 224,
@@ -45,11 +43,10 @@ def letterbox_any_size(
45
  # Обрезка по маске SAM + белый фон + letterbox
46
  def crop_fruit_with_white_bg(
47
  orig_img: np.ndarray, # RGB
48
- mask: np.ndarray, # bool или uint8 от SAM
49
  out_size: int = 224,
50
  bg_color: tuple = (255, 255, 255)
51
  ) -> np.ndarray:
52
- # Маска → binary
53
  mask_bin = mask.astype(np.uint8)
54
 
55
  ys, xs = np.where(mask_bin == 1)
@@ -62,5 +59,4 @@ def crop_fruit_with_white_bg(
62
  cropped = orig_img[y1:y2+1, x1:x2+1].copy()
63
 
64
  final = letterbox_any_size(cropped, target_size=out_size, bg_color=bg_color)
65
-
66
  return final
 
1
  import numpy as np
2
  import cv2
 
3
  from PIL import Image
4
  import io
5
  import base64
6
+ import torch
7
  from torchvision import transforms
 
8
 
 
9
  FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
10
  FRESHNESS_CLASSES = ['freshapples', 'freshbanana', 'freshoranges', 'rottenapples', 'rottenbanana', 'rottenoranges']
11
 
 
17
  ])
18
  return transform(img)
19
 
20
+ # Letterbox без искажения пропорций
21
  def letterbox_any_size(
22
  img: np.ndarray,
23
  target_size: int = 224,
 
43
  # Обрезка по маске SAM + белый фон + letterbox
44
  def crop_fruit_with_white_bg(
45
  orig_img: np.ndarray, # RGB
46
+ mask: np.ndarray, # bool от SAM
47
  out_size: int = 224,
48
  bg_color: tuple = (255, 255, 255)
49
  ) -> np.ndarray:
 
50
  mask_bin = mask.astype(np.uint8)
51
 
52
  ys, xs = np.where(mask_bin == 1)
 
59
  cropped = orig_img[y1:y2+1, x1:x2+1].copy()
60
 
61
  final = letterbox_any_size(cropped, target_size=out_size, bg_color=bg_color)
 
62
  return final