mobilesam v1.0
- app.py +45 -28
- models.py +6 -11
- requirements.txt +2 -1
- utils.py +13 -102
- weights/{seg0.pth → mobile_sam.pt} +2 -2
app.py
CHANGED
@@ -1,30 +1,26 @@
 from fastapi import FastAPI, UploadFile, File, Query
 import torch
+import numpy as np
+from PIL import Image
+import base64
+import io
 from models import load_model1, load_model2, load_model3
 from utils import (
-    preprocess_image,
-    predict_mask_tta,
-    postprocess_mask,
-    mask_to_base64,
-    apply_white_background_and_crop,
+    crop_fruit_with_white_bg,
     preprocess_for_classifier,
     FRUIT_CLASSES,
     FRESHNESS_CLASSES
 )
-import numpy as np
-from PIL import Image
-import io
 
 app = FastAPI()
 
-# ...
-model1 = load_model1()
-model2 = load_model2()
-model3 = load_model3()
+# Load the models
+sam_predictor = load_model1()  # MobileSAM
+model2 = load_model2()
+model3 = load_model3()
 
 DEVICE = torch.device('cpu')
 
-# Classes we run the freshness check on
 FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange', 'lemon'}
 
 @app.get("/")
@@ -34,19 +30,33 @@ def greet_json():
 @app.post("/predict_full")
 async def predict_full(
     file: UploadFile = File(...),
-    ...
+    point_x: int = Query(..., description="X coordinate of the point on the fruit (in pixels of the original image)"),
+    point_y: int = Query(..., description="Y coordinate of the point on the fruit"),
+    return_cropped: bool = Query(default=True, description="Return the cropped image as base64?")
 ):
     content = await file.read()
     image = Image.open(io.BytesIO(content)).convert('RGB')
     orig_np = np.array(image)
 
-    # ...
-    ...
-    with torch.no_grad():
-        prob = predict_mask_tta(model1, input_tensor)
-        mask = postprocess_mask(prob.squeeze().cpu().numpy())
+    # Set the image in SAM
+    sam_predictor.set_image(orig_np)
 
-    fruit_area_ratio = ...
+    # Prompt: a point on the fruit
+    input_point = np.array([[point_x, point_y]])
+    input_label = np.array([1])  # 1 = foreground
+
+    masks, scores, _ = sam_predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=False  # a single mask
+    )
+
+    # Take the best mask
+    best_mask_idx = np.argmax(scores)
+    mask = masks[best_mask_idx]  # bool
+
+    # Check: is there a fruit at all?
+    fruit_area_ratio = np.mean(mask)
     if fruit_area_ratio < 0.01:
         return {
             "status": "no_fruit_detected",
@@ -55,12 +65,11 @@ async def predict_full(
             "fruit_confidence": None,
             "freshness": None,
             "freshness_confidence": None,
-            "...
+            "cropped_base64": None
         }
 
-    # ...
-    cropped_100 = ...
-    ...
+    # Crop to 100×100 for the fruit-type classifier
+    cropped_100 = crop_fruit_with_white_bg(orig_np, mask, out_size=100)
     input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
     with torch.no_grad():
         logits2 = model2(input_tensor2)
@@ -77,13 +86,12 @@ async def predict_full(
             "fruit_confidence": round(fruit_conf, 4),
             "freshness": None,
             "freshness_confidence": None,
-            "...
+            "cropped_base64": None
         }
 
-    # ...
+    # Freshness, if the fruit is eligible
     if fruit_name in FRESHNESS_ELIGIBLE:
-        cropped_224 = ...
-        ...
+        cropped_224 = crop_fruit_with_white_bg(orig_np, mask, out_size=224)
         input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
         with torch.no_grad():
             logits3 = model3(input_tensor3)
@@ -96,4 +104,13 @@ async def predict_full(
         result["freshness"] = fresh_name
         result["freshness_confidence"] = round(fresh_conf, 4)
 
+    # Return the cropped image (224×224 by default)
+    if return_cropped:
+        cropped_final = crop_fruit_with_white_bg(orig_np, mask, out_size=224)
+        pil_img = Image.fromarray(cropped_final)
+        buffered = io.BytesIO()
+        pil_img.save(buffered, format="PNG")
+        result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
+        result["cropped_size"] = "224x224"
+
    return result
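For reference, a minimal client for the new endpoint could look like the sketch below. The host, port, image path, and point coordinates are placeholders (Hugging Face Spaces typically serve FastAPI apps on port 7860); the response keys follow the handler above.

import base64
import requests

# Hypothetical URL: adjust to wherever this Space is served.
URL = "http://localhost:7860/predict_full"

with open("fruit.jpg", "rb") as f:
    resp = requests.post(
        URL,
        files={"file": ("fruit.jpg", f, "image/jpeg")},
        # The point prompt is given in pixel coordinates of the original image.
        params={"point_x": 320, "point_y": 240, "return_cropped": True},
    )
resp.raise_for_status()
data = resp.json()
print(data["status"], data.get("fruit_confidence"), data.get("freshness"))

# The cropped fruit comes back as a base64-encoded PNG.
if data.get("cropped_base64"):
    with open("cropped.png", "wb") as out:
        out.write(base64.b64decode(data["cropped_base64"]))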
models.py
CHANGED
@@ -2,25 +2,20 @@ import torch
 import torchvision.models as models
 import torch.nn as nn
 import segmentation_models_pytorch as smp
+from mobile_sam import sam_model_registry, SamPredictor
 
 DEVICE = torch.device('cpu')
 
-model1 = None  # ...
+model1 = None  # now MobileSAM
 model2 = None  # fruit-type classifier
 model3 = None  # freshness classifier
 
-def load_model1(weights_path='weights/seg0.pth'):
+def load_model1(weights_path='weights/mobile_sam.pt'):
     global model1
     if model1 is None:
-        ...
-            ...
-            ...
-            in_channels=3,
-            classes=1,
-            activation=None
-        ).to(DEVICE)
-        state_dict = torch.load(weights_path, map_location=DEVICE)
-        model1.load_state_dict(state_dict)
+        model_type = "vit_t"
+        model1 = sam_model_registry[model_type](checkpoint=weights_path)
+        model1.to(DEVICE)
     model1.eval()
     return model1
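One caveat worth flagging: app.py calls .set_image() and .predict() on the object returned by load_model1(), but those methods live on SamPredictor (which this file now imports, apparently unused), not on the raw vit_t module. A minimal sketch of the wrapping that would reconcile the two files, assuming the MobileSAM package mirrors the upstream segment-anything predictor API; load_sam_predictor is a hypothetical name, not part of this commit:

import torch
from mobile_sam import sam_model_registry, SamPredictor

DEVICE = torch.device('cpu')

def load_sam_predictor(weights_path='weights/mobile_sam.pt') -> SamPredictor:
    # Build the ViT-Tiny backbone and load the checkpoint.
    sam = sam_model_registry["vit_t"](checkpoint=weights_path)
    sam.to(DEVICE)
    sam.eval()
    # SamPredictor exposes set_image()/predict(), which app.py relies on.
    return SamPredictor(sam)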
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ albumentations
 pillow
 numpy
 opencv-python-headless
-python-multipart
+python-multipart
+git+https://github.com/ChaoningZhang/MobileSAM.git
utils.py
CHANGED
@@ -1,100 +1,16 @@
 import numpy as np
-import albumentations as A
-from albumentations.pytorch import ToTensorV2
-import torch
 import cv2
+import torch
 from PIL import Image
 import io
 import base64
 from torchvision import transforms
+from mobile_sam import SamPredictor
 
-# ────────────────────────────────────────────────
-# New model input size: 448×448
-# ────────────────────────────────────────────────
-IMG_SIZE = 448
-
-preprocess_transform = A.Compose([
-    A.Resize(IMG_SIZE, IMG_SIZE),
-    A.Normalize(),  # ImageNet mean/std, same as in training
-    ToTensorV2()
-])
-
-def preprocess_image(image_np: np.ndarray) -> torch.Tensor:
-    augmented = preprocess_transform(image=image_np)
-    return augmented['image']
-
-# ────────────────────────────────────────────────
-# TTA prediction (as in your example)
-# ────────────────────────────────────────────────
-@torch.no_grad()
-def predict_mask_tta(model, image_tensor):
-    preds = []
-    # Original
-    preds.append(torch.sigmoid(model(image_tensor)))
-    # Flip horizontal
-    preds.append(
-        torch.flip(
-            torch.sigmoid(model(torch.flip(image_tensor, dims=[3]))),
-            dims=[3]
-        )
-    )
-    # Flip vertical
-    preds.append(
-        torch.flip(
-            torch.sigmoid(model(torch.flip(image_tensor, dims=[2]))),
-            dims=[2]
-        )
-    )
-    return torch.mean(torch.stack(preds), dim=0)
-
-# ────────────────────────────────────────────────
-# Mask post-processing (as in your example, plus morphology)
-# ────────────────────────────────────────────────
-def postprocess_mask(prob: np.ndarray, threshold: float = 0.65, min_area_ratio: float = 0.01) -> np.ndarray:
-    binary = (prob > threshold).astype(np.uint8)
-
-    # Connected components: keep only the main object
-    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
-
-    if num_labels <= 1:
-        return binary.astype(np.float32)
-
-    largest_label = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
-    area = stats[largest_label, cv2.CC_STAT_AREA]
-
-    if area < binary.shape[0] * binary.shape[1] * min_area_ratio:
-        return np.zeros_like(binary, dtype=np.float32)
-
-    clean_mask = (labels == largest_label).astype(np.float32)
-
-    # Morphology (fill holes, remove noise)
-    kernel = np.ones((3, 3), np.uint8)
-    clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_CLOSE, kernel)
-    clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel)
-
-    return clean_mask
-
-# ────────────────────────────────────────────────
-# Base64-encode the mask (returned to the client)
-# ────────────────────────────────────────────────
-def mask_to_base64(mask: np.ndarray) -> str:
-    pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
-    buffered = io.BytesIO()
-    pil_mask.save(buffered, format="PNG")
-    return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-# ────────────────────────────────────────────────
-# Class constants
-# ────────────────────────────────────────────────
+# Constants
 FRUIT_CLASSES = ['apple', 'banana', 'orange', 'strawberry', 'pear', 'lemon', 'cucumber', 'plum', 'raspberry', 'watermelon']
-FRESHNESS_CLASSES = [
-    'freshapples', 'freshbanana', 'freshoranges',
-    'rottenapples', 'rottenbanana', 'rottenoranges'
-]
+FRESHNESS_CLASSES = ['freshapples', 'freshbanana', 'freshoranges', 'rottenapples', 'rottenbanana', 'rottenoranges']
 
-# ────────────────────────────────────────────────
-# Preprocessing for the classifiers (100 and 224)
-# ────────────────────────────────────────────────
 def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
     transform = transforms.Compose([
         transforms.ToPILImage(),
@@ -103,9 +19,7 @@ def preprocess_for_classifier(img: np.ndarray) -> torch.Tensor:
     ])
     return transform(img)
 
-# ────────────────────────────────────────────────
-# Universal letterbox (for any target_size)
-# ────────────────────────────────────────────────
+# Universal letterbox (no distortion)
 def letterbox_any_size(
     img: np.ndarray,
     target_size: int = 224,
@@ -124,22 +38,19 @@
     left = pad_w // 2
     right = pad_w - left
 
-    padded = cv2.copyMakeBorder(
-        ...
-        cv2.BORDER_CONSTANT, value=bg_color
-    )
+    padded = cv2.copyMakeBorder(resized, top, bottom, left, right,
+                                cv2.BORDER_CONSTANT, value=bg_color)
     return padded
 
-# ────────────────────────────────────────────────
-# ...
-# ────────────────────────────────────────────────
-def apply_white_background_and_crop(
-    orig_img: np.ndarray,  # RGB
-    mask: np.ndarray,      # float [0,1], 448×448
+# Crop by the SAM mask + white background + letterbox
+def crop_fruit_with_white_bg(
+    orig_img: np.ndarray,  # RGB
+    mask: np.ndarray,      # bool or uint8, from SAM
     out_size: int = 224,
     bg_color: tuple = (255, 255, 255)
 ) -> np.ndarray:
-    ...
+    # Mask → binary
+    mask_bin = mask.astype(np.uint8)
 
     ys, xs = np.where(mask_bin == 1)
     if len(xs) == 0:
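The last hunk is cut off before the body of crop_fruit_with_white_bg finishes. From the visible pieces (mask_bin, the ys/xs bounding-box lookup, bg_color, and letterbox_any_size defined in the same module), the rest of the function plausibly proceeds as in this sketch; treat it as a reconstruction under those assumptions, not the committed code:

import numpy as np

def crop_fruit_with_white_bg_sketch(orig_img, mask, out_size=224,
                                    bg_color=(255, 255, 255)):
    mask_bin = mask.astype(np.uint8)

    ys, xs = np.where(mask_bin == 1)
    if len(xs) == 0:
        # Empty mask: return a plain background canvas.
        return np.full((out_size, out_size, 3), bg_color, dtype=np.uint8)

    # Bounding box of the fruit.
    x1, x2 = xs.min(), xs.max()
    y1, y2 = ys.min(), ys.max()

    # Paint everything outside the mask with the background color.
    whited = orig_img.copy()
    whited[mask_bin == 0] = bg_color

    # Crop to the box, then letterbox to a square without distortion
    # (reuses letterbox_any_size from this module).
    crop = whited[y1:y2 + 1, x1:x2 + 1]
    return letterbox_any_size(crop, target_size=out_size, bg_color=bg_color)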
weights/{seg0.pth → mobile_sam.pt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f
+size 40728226