mobilesam v1.1
Browse files
app.py
CHANGED
|
@@ -2,9 +2,9 @@ from fastapi import FastAPI, UploadFile, File, Query
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from PIL import Image
|
| 5 |
-
import base64
|
| 6 |
import io
|
| 7 |
-
|
|
|
|
| 8 |
from utils import (
|
| 9 |
crop_fruit_with_white_bg,
|
| 10 |
preprocess_for_classifier,
|
|
@@ -15,9 +15,9 @@ from utils import (
|
|
| 15 |
app = FastAPI()
|
| 16 |
|
| 17 |
# Загрузка моделей
|
| 18 |
-
|
| 19 |
-
model2 = load_model2()
|
| 20 |
-
model3 = load_model3()
|
| 21 |
|
| 22 |
DEVICE = torch.device('cpu')
|
| 23 |
|
|
@@ -32,23 +32,22 @@ async def predict_full(
|
|
| 32 |
file: UploadFile = File(...),
|
| 33 |
point_x: int = Query(..., description="X-координата точки на фрукте (в пикселях оригинального изображения)"),
|
| 34 |
point_y: int = Query(..., description="Y-координата точки на фрукте"),
|
| 35 |
-
return_cropped: bool = Query(default=True, description="Вернуть обрезанное изображение в base64?")
|
|
|
|
| 36 |
):
|
| 37 |
content = await file.read()
|
| 38 |
image = Image.open(io.BytesIO(content)).convert('RGB')
|
| 39 |
orig_np = np.array(image)
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
sam_predictor.set_image(orig_np)
|
| 43 |
-
|
| 44 |
-
# Промпт: точка на фрукте
|
| 45 |
input_point = np.array([[point_x, point_y]])
|
| 46 |
input_label = np.array([1]) # 1 = foreground
|
| 47 |
|
| 48 |
-
masks, scores, _ =
|
|
|
|
| 49 |
point_coords=input_point,
|
| 50 |
point_labels=input_label,
|
| 51 |
-
multimask_output=False
|
| 52 |
)
|
| 53 |
|
| 54 |
# Берём лучшую маску
|
|
@@ -104,13 +103,13 @@ async def predict_full(
|
|
| 104 |
result["freshness"] = fresh_name
|
| 105 |
result["freshness_confidence"] = round(fresh_conf, 4)
|
| 106 |
|
| 107 |
-
# Возвращаем обрезанное изображение (по умолчанию
|
| 108 |
if return_cropped:
|
| 109 |
-
cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=
|
| 110 |
pil_img = Image.fromarray(cropped_final)
|
| 111 |
buffered = io.BytesIO()
|
| 112 |
pil_img.save(buffered, format="PNG")
|
| 113 |
result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 114 |
-
result["cropped_size"] = "
|
| 115 |
|
| 116 |
return result
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from PIL import Image
|
|
|
|
| 5 |
import io
|
| 6 |
+
import base64
|
| 7 |
+
from models import load_sam, load_model2, load_model3
|
| 8 |
from utils import (
|
| 9 |
crop_fruit_with_white_bg,
|
| 10 |
preprocess_for_classifier,
|
|
|
|
| 15 |
app = FastAPI()
|
| 16 |
|
| 17 |
# Загрузка моделей
|
| 18 |
+
sam_model = load_sam() # MobileSAM
|
| 19 |
+
model2 = load_model2() # сорт
|
| 20 |
+
model3 = load_model3() # свежесть
|
| 21 |
|
| 22 |
DEVICE = torch.device('cpu')
|
| 23 |
|
|
|
|
| 32 |
file: UploadFile = File(...),
|
| 33 |
point_x: int = Query(..., description="X-координата точки на фрукте (в пикселях оригинального изображения)"),
|
| 34 |
point_y: int = Query(..., description="Y-координата точки на фрукте"),
|
| 35 |
+
return_cropped: bool = Query(default=True, description="Вернуть обрезанное изображение в base64?"),
|
| 36 |
+
cropped_size: int = Query(224, description="Размер обрезанного изображения (100 или 224)")
|
| 37 |
):
|
| 38 |
content = await file.read()
|
| 39 |
image = Image.open(io.BytesIO(content)).convert('RGB')
|
| 40 |
orig_np = np.array(image)
|
| 41 |
|
| 42 |
+
# MobileSAM: сегментация по точке
|
|
|
|
|
|
|
|
|
|
| 43 |
input_point = np.array([[point_x, point_y]])
|
| 44 |
input_label = np.array([1]) # 1 = foreground
|
| 45 |
|
| 46 |
+
masks, scores, _ = sam_model.predict(
|
| 47 |
+
image=orig_np,
|
| 48 |
point_coords=input_point,
|
| 49 |
point_labels=input_label,
|
| 50 |
+
multimask_output=False
|
| 51 |
)
|
| 52 |
|
| 53 |
# Берём лучшую маску
|
|
|
|
| 103 |
result["freshness"] = fresh_name
|
| 104 |
result["freshness_confidence"] = round(fresh_conf, 4)
|
| 105 |
|
| 106 |
+
# Возвращаем обрезанное изображение (по умолчанию cropped_size)
|
| 107 |
if return_cropped:
|
| 108 |
+
cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=cropped_size)
|
| 109 |
pil_img = Image.fromarray(cropped_final)
|
| 110 |
buffered = io.BytesIO()
|
| 111 |
pil_img.save(buffered, format="PNG")
|
| 112 |
result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 113 |
+
result["cropped_size"] = f"{cropped_size}x{cropped_size}"
|
| 114 |
|
| 115 |
return result
|
models.py
CHANGED
|
@@ -1,23 +1,22 @@
|
|
| 1 |
import torch
|
| 2 |
import torchvision.models as models
|
| 3 |
import torch.nn as nn
|
| 4 |
-
|
| 5 |
-
from mobile_sam import sam_model_registry, SamPredictor
|
| 6 |
|
| 7 |
DEVICE = torch.device('cpu')
|
| 8 |
|
| 9 |
-
|
| 10 |
-
model2 = None
|
| 11 |
-
model3 = None
|
| 12 |
|
| 13 |
-
def
|
| 14 |
-
global
|
| 15 |
-
if
|
| 16 |
model_type = "vit_t"
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
return
|
| 21 |
|
| 22 |
def load_model2(weights_path='weights/class.pth'):
|
| 23 |
global model2
|
|
|
|
| 1 |
import torch
|
| 2 |
import torchvision.models as models
|
| 3 |
import torch.nn as nn
|
| 4 |
+
from mobile_sam import sam_model_registry
|
|
|
|
| 5 |
|
| 6 |
DEVICE = torch.device('cpu')
|
| 7 |
|
| 8 |
+
sam_model = None # MobileSAM
|
| 9 |
+
model2 = None # сорт фрукта
|
| 10 |
+
model3 = None # свежесть
|
| 11 |
|
| 12 |
+
def load_sam(weights_path='weights/mobile_sam.pt'):
    """Lazily load and cache the MobileSAM model (module-level singleton).

    Args:
        weights_path: Path to the MobileSAM checkpoint file.

    Returns:
        The cached ``vit_t`` MobileSAM model, moved to ``DEVICE`` and set
        to eval mode.

    NOTE(review): this returns the raw SAM model; callers that use the
    ``predictor.set_image(...)`` / ``predictor.predict(...)`` API likely
    need it wrapped in ``SamPredictor`` — confirm against app.py usage.
    """
    global sam_model
    # Already loaded on a previous call — reuse the cached instance.
    if sam_model is not None:
        return sam_model
    sam_model = sam_model_registry["vit_t"](checkpoint=weights_path)
    sam_model.to(DEVICE)
    sam_model.eval()
    return sam_model
|
| 20 |
|
| 21 |
def load_model2(weights_path='weights/class.pth'):
|
| 22 |
global model2
|
utils.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import cv2
|
| 3 |
-
import torch
|
| 4 |
from PIL import Image
|
| 5 |
import io
|
| 6 |
import base64
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
-
from mobile_sam import SamPredictor
|
| 9 |
|
| 10 |
-
# Константы
|
| 11 |
FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
|
| 12 |
FRESHNESS_CLASSES = ['freshapples', 'freshbanana', 'freshoranges', 'rottenapples', 'rottenbanana', 'rottenoranges']
|
| 13 |
|
|
@@ -19,7 +17,7 @@ def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
|
|
| 19 |
])
|
| 20 |
return transform(img)
|
| 21 |
|
| 22 |
-
#
|
| 23 |
def letterbox_any_size(
|
| 24 |
img: np.ndarray,
|
| 25 |
target_size: int = 224,
|
|
@@ -45,11 +43,10 @@ def letterbox_any_size(
|
|
| 45 |
# Обрезка по маске SAM + белый фон + letterbox
|
| 46 |
def crop_fruit_with_white_bg(
|
| 47 |
orig_img: np.ndarray, # RGB
|
| 48 |
-
mask: np.ndarray, # bool
|
| 49 |
out_size: int = 224,
|
| 50 |
bg_color: tuple = (255, 255, 255)
|
| 51 |
) -> np.ndarray:
|
| 52 |
-
# Маска → binary
|
| 53 |
mask_bin = mask.astype(np.uint8)
|
| 54 |
|
| 55 |
ys, xs = np.where(mask_bin == 1)
|
|
@@ -62,5 +59,4 @@ def crop_fruit_with_white_bg(
|
|
| 62 |
cropped = orig_img[y1:y2+1, x1:x2+1].copy()
|
| 63 |
|
| 64 |
final = letterbox_any_size(cropped, target_size=out_size, bg_color=bg_color)
|
| 65 |
-
|
| 66 |
return final
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import cv2
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
import io
|
| 5 |
import base64
|
| 6 |
+
import torch
|
| 7 |
from torchvision import transforms
|
|
|
|
| 8 |
|
|
|
|
| 9 |
# Classifier output labels.
# NOTE(review): order is assumed to match the training label indices of the
# fruit-type and freshness classifiers — confirm against the model weights.
FRUIT_CLASSES = [
    'apple', 'banana', 'orange', 'strawberry', 'pear',
    'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon',
]
FRESHNESS_CLASSES = [
    'freshapples', 'freshbanana', 'freshoranges',
    'rottenapples', 'rottenbanana', 'rottenoranges',
]
|
| 11 |
|
|
|
|
| 17 |
])
|
| 18 |
return transform(img)
|
| 19 |
|
| 20 |
+
# Letterbox без искажения пропорций
|
| 21 |
def letterbox_any_size(
|
| 22 |
img: np.ndarray,
|
| 23 |
target_size: int = 224,
|
|
|
|
| 43 |
# Обрезка по маске SAM + белый фон + letterbox
|
| 44 |
def crop_fruit_with_white_bg(
|
| 45 |
orig_img: np.ndarray, # RGB
|
| 46 |
+
mask: np.ndarray, # bool от SAM
|
| 47 |
out_size: int = 224,
|
| 48 |
bg_color: tuple = (255, 255, 255)
|
| 49 |
) -> np.ndarray:
|
|
|
|
| 50 |
mask_bin = mask.astype(np.uint8)
|
| 51 |
|
| 52 |
ys, xs = np.where(mask_bin == 1)
|
|
|
|
| 59 |
cropped = orig_img[y1:y2+1, x1:x2+1].copy()
|
| 60 |
|
| 61 |
final = letterbox_any_size(cropped, target_size=out_size, bg_color=bg_color)
|
|
|
|
| 62 |
return final
|