ivanm151 committed on
Commit
fbb0759
·
1 Parent(s): 9dd116b

mobilesam v1.4

Browse files
Files changed (3) hide show
  1. app.py +9 -7
  2. models.py +11 -10
  3. utils.py +2 -4
app.py CHANGED
@@ -15,9 +15,9 @@ from utils import (
15
  app = FastAPI()
16
 
17
  # Загрузка моделей
18
- sam_model = load_sam() # MobileSAM
19
- model2 = load_model2() # сорт
20
- model3 = load_model3() # свежесть
21
 
22
  DEVICE = torch.device('cpu')
23
 
@@ -39,13 +39,15 @@ async def predict_full(
39
  image = Image.open(io.BytesIO(content)).convert('RGB')
40
  orig_np = np.array(image)
41
 
42
- # MobileSAM: сегментация по точке
 
 
 
43
  input_point = np.array([[point_x, point_y]])
44
  input_label = np.array([1]) # 1 = foreground
45
 
46
- # Правильный вызов MobileSAM (из официального примера)
47
- masks, scores, logits = sam_model.predict(
48
- img=orig_np,
49
  point_coords=input_point,
50
  point_labels=input_label,
51
  multimask_output=False
 
15
  app = FastAPI()
16
 
17
  # Загрузка моделей
18
+ sam_predictor = load_sam() # MobileSAM + SamPredictor
19
+ model2 = load_model2()
20
+ model3 = load_model3()
21
 
22
  DEVICE = torch.device('cpu')
23
 
 
39
  image = Image.open(io.BytesIO(content)).convert('RGB')
40
  orig_np = np.array(image)
41
 
42
+ # Установка изображения в MobileSAM Predictor
43
+ sam_predictor.set_image(orig_np)
44
+
45
+ # Промпт: точка на фрукте
46
  input_point = np.array([[point_x, point_y]])
47
  input_label = np.array([1]) # 1 = foreground
48
 
49
+ # Предсказание маски
50
+ masks, scores, logits = sam_predictor.predict(
 
51
  point_coords=input_point,
52
  point_labels=input_label,
53
  multimask_output=False
models.py CHANGED
@@ -1,22 +1,23 @@
1
  import torch
2
  import torchvision.models as models
3
  import torch.nn as nn
4
- from mobile_sam import sam_model_registry
5
 
6
  DEVICE = torch.device('cpu')
7
 
8
- sam_model = None # MobileSAM
9
- model2 = None # сорт фрукта
10
- model3 = None # свежесть
11
 
12
  def load_sam(weights_path='weights/mobile_sam.pt'):
13
- global sam_model
14
- if sam_model is None:
15
  model_type = "vit_t"
16
- sam_model = sam_model_registry[model_type](checkpoint=weights_path)
17
- sam_model.to(DEVICE)
18
- sam_model.eval()
19
- return sam_model
 
20
 
21
  def load_model2(weights_path='weights/class.pth'):
22
  global model2
 
1
  import torch
2
  import torchvision.models as models
3
  import torch.nn as nn
4
+ from mobile_sam import sam_model_registry, SamPredictor
5
 
6
  DEVICE = torch.device('cpu')
7
 
8
+ sam_predictor = None # MobileSAM + Predictor
9
+ model2 = None # сорт фрукта
10
+ model3 = None # свежесть
11
 
12
  def load_sam(weights_path='weights/mobile_sam.pt'):
13
+ global sam_predictor
14
+ if sam_predictor is None:
15
  model_type = "vit_t"
16
+ sam = sam_model_registry[model_type](checkpoint=weights_path)
17
+ sam.to(DEVICE)
18
+ sam.eval()
19
+ sam_predictor = SamPredictor(sam)
20
+ return sam_predictor
21
 
22
  def load_model2(weights_path='weights/class.pth'):
23
  global model2
utils.py CHANGED
@@ -17,7 +17,6 @@ def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
17
  ])
18
  return transform(img)
19
 
20
- # Letterbox без искажения
21
  def letterbox_any_size(
22
  img: np.ndarray,
23
  target_size: int = 224,
@@ -40,10 +39,9 @@ def letterbox_any_size(
40
  cv2.BORDER_CONSTANT, value=bg_color)
41
  return padded
42
 
43
- # Обрезка по маске SAM + белый фон + letterbox
44
  def crop_fruit_with_white_bg(
45
- orig_img: np.ndarray, # RGB
46
- mask: np.ndarray, # bool от SAM
47
  out_size: int = 224,
48
  bg_color: tuple = (255, 255, 255)
49
  ) -> np.ndarray:
 
17
  ])
18
  return transform(img)
19
 
 
20
  def letterbox_any_size(
21
  img: np.ndarray,
22
  target_size: int = 224,
 
39
  cv2.BORDER_CONSTANT, value=bg_color)
40
  return padded
41
 
 
42
  def crop_fruit_with_white_bg(
43
+ orig_img: np.ndarray,
44
+ mask: np.ndarray,
45
  out_size: int = 224,
46
  bg_color: tuple = (255, 255, 255)
47
  ) -> np.ndarray: