ivanm151 commited on
Commit
1c6f885
·
1 Parent(s): 3dc4dee

add class1

Browse files
Files changed (4) hide show
  1. app.py +68 -12
  2. models.py +22 -3
  3. requirements.txt +3 -1
  4. utils.py +103 -6
app.py CHANGED
@@ -1,15 +1,28 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  import torch
3
- from models import load_model1
4
- from utils import preprocess_image, postprocess_mask, resize_mask, mask_to_base64
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  from PIL import Image
7
  import io
8
 
9
  app = FastAPI()
10
 
11
- # Загрузка модели при старте (глобально, один раз)
12
- model1 = load_model1()
 
 
 
13
 
14
 
15
  @app.get("/")
@@ -19,20 +32,63 @@ def greet_json():
19
 
20
  @app.post("/predict1")
21
  async def predict1(file: UploadFile = File(...)):
 
 
 
22
  content = await file.read()
23
  image = Image.open(io.BytesIO(content)).convert('RGB')
24
  image_np = np.array(image)
25
 
26
- input_tensor = preprocess_image(image_np)
 
27
  with torch.no_grad():
28
- logits = model1(input_tensor.unsqueeze(0)) # batch dim
 
 
 
 
 
 
 
 
29
 
30
- pred_mask = postprocess_mask(logits) # (256, 256) binary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- mask_100 = resize_mask(pred_mask, 100)
33
- mask_224 = resize_mask(pred_mask, 224)
34
 
35
  return {
36
- "mask_100_base64": mask_to_base64(mask_100),
37
- "mask_224_base64": mask_to_base64(mask_224)
 
38
  }
 
1
+ from fastapi import FastAPI, UploadFile, File, Form
2
  import torch
3
+ from models import load_model1, load_model2
4
+ from utils import (
5
+ # Для /predict1
6
+ preprocess_image,
7
+ postprocess_mask,
8
+ mask_to_base64,
9
+ # Для /predict2
10
+ decode_base64_mask,
11
+ apply_mask_and_crop_letterbox,
12
+ preprocess_for_classifier,
13
+ FRUIT_CLASSES
14
+ )
15
  import numpy as np
16
  from PIL import Image
17
  import io
18
 
19
  app = FastAPI()
20
 
21
+ # Загрузка моделей один раз при старте
22
+ model1 = load_model1() # сегментация → weights/model1.pth
23
+ model2 = load_model2() # классификатор → weights/model2.pth
24
+
25
+ DEVICE = torch.device('cpu')
26
 
27
 
28
  @app.get("/")
 
32
 
33
  @app.post("/predict1")
34
  async def predict1(file: UploadFile = File(...)):
35
+ """
36
+ Сегментация фрукта → возвращает маску 256×256 в base64
37
+ """
38
  content = await file.read()
39
  image = Image.open(io.BytesIO(content)).convert('RGB')
40
  image_np = np.array(image)
41
 
42
+ input_tensor = preprocess_image(image_np).unsqueeze(0).to(DEVICE)
43
+
44
  with torch.no_grad():
45
+ logits = model1(input_tensor)
46
+
47
+ pred_mask = postprocess_mask(logits) # shape (256, 256), float [0,1]
48
+
49
+ # Возвращаем только одну маску — 256×256
50
+ return {
51
+ "mask_256_base64": mask_to_base64(pred_mask)
52
+ }
53
+
54
 
55
+ @app.post("/predict2")
56
+ async def predict2(
57
+ file: UploadFile = File(...), # оригинальное изображение (любого размера)
58
+ mask_256_base64: str = Form(...) # маска 256×256 от /predict1
59
+ ):
60
+ """
61
+ Классификация фрукта:
62
+ - ресайз оригинала → letterbox 256×256
63
+ - применение маски
64
+ - crop по bounding box + margin
65
+ - ресайз результата до 100×100
66
+ - inference MobileNetV2
67
+ """
68
+ # 1. Оригинальное изображение
69
+ content = await file.read()
70
+ original_pil = Image.open(io.BytesIO(content)).convert('RGB')
71
+ original_np = np.array(original_pil)
72
+
73
+ # 2. Декодируем маску 256×256
74
+ mask_256 = decode_base64_mask(mask_256_base64)
75
+
76
+ # 3. Letterbox + маска + crop + resize до 100×100
77
+ cropped_100 = apply_mask_and_crop_letterbox(original_np, mask_256)
78
+
79
+ # 4. Препроцессинг для классификатора
80
+ input_tensor = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
81
+
82
+ # 5. Инференс
83
+ with torch.no_grad():
84
+ logits = model2(input_tensor)
85
+ probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
86
 
87
+ pred_idx = int(np.argmax(probs))
88
+ confidence = float(probs[pred_idx])
89
 
90
  return {
91
+ "predicted_fruit": FRUIT_CLASSES[pred_idx],
92
+ "confidence": round(confidence, 4),
93
+ "class_index": pred_idx
94
  }
models.py CHANGED
@@ -1,13 +1,16 @@
1
  import torch
2
- import segmentation_models_pytorch as smp
 
3
 
4
  DEVICE = torch.device('cpu')
5
 
6
- model1 = None
 
7
 
8
  def load_model1(weights_path='weights/seg.pth'):
9
  global model1
10
  if model1 is None:
 
11
  model1 = smp.Unet(
12
  encoder_name="mobilenet_v2",
13
  encoder_weights=None,
@@ -17,4 +20,20 @@ def load_model1(weights_path='weights/seg.pth'):
17
  state_dict = torch.load(weights_path, map_location=DEVICE)
18
  model1.load_state_dict(state_dict)
19
  model1.eval()
20
- return model1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torchvision.models as models
3
+ import torch.nn as nn
4
 
5
  DEVICE = torch.device('cpu')
6
 
7
+ model1 = None # сегментация (как раньше)
8
+ model2 = None # классификатор
9
 
10
  def load_model1(weights_path='weights/seg.pth'):
11
  global model1
12
  if model1 is None:
13
+ import segmentation_models_pytorch as smp
14
  model1 = smp.Unet(
15
  encoder_name="mobilenet_v2",
16
  encoder_weights=None,
 
20
  state_dict = torch.load(weights_path, map_location=DEVICE)
21
  model1.load_state_dict(state_dict)
22
  model1.eval()
23
+ return model1
24
+
25
+
26
+ def load_model2(weights_path='weights/class1.pth'):
27
+ global model2
28
+ if model2 is None:
29
+ model2 = models.mobilenet_v2(pretrained=False)
30
+ # Замораживаем features (как в обучении)
31
+ for param in model2.features.parameters():
32
+ param.requires_grad = False
33
+ # Заменяем classifier
34
+ model2.classifier[1] = nn.Linear(model2.classifier[1].in_features, 10)
35
+ state_dict = torch.load(weights_path, map_location=DEVICE)
36
+ model2.load_state_dict(state_dict)
37
+ model2.to(DEVICE)
38
+ model2.eval()
39
+ return model2
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
  fastapi
2
  uvicorn[standard]
3
  torch
 
4
  segmentation_models_pytorch
5
  albumentations
6
  pillow
7
  numpy
8
- opencv-python-headless
 
 
1
  fastapi
2
  uvicorn[standard]
3
  torch
4
+ torchvision
5
  segmentation_models_pytorch
6
  albumentations
7
  pillow
8
  numpy
9
+ opencv-python-headless
10
+ python-multipart
utils.py CHANGED
@@ -6,6 +6,7 @@ import cv2
6
  from PIL import Image
7
  import io
8
  import base64
 
9
 
10
  # Препроцессинг: аналог валидации
11
  preprocess_transform = A.Compose([
@@ -23,14 +24,110 @@ def postprocess_mask(logits: torch.Tensor, threshold: float = 0.5) -> np.ndarray
23
  binary_mask = (pred > threshold).astype(np.float32)
24
  return binary_mask # shape (256, 256)
25
 
26
- def resize_mask(mask: np.ndarray, size: int) -> np.ndarray:
27
- # Resize с nearest neighbor для бинарных масок
28
- resized = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)
29
- return resized.astype(np.float32) # 0/1 float
 
 
 
30
 
31
  def mask_to_base64(mask: np.ndarray) -> str:
32
- # Конверт в PIL grayscale (0/255), save as PNG, base64
33
  pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
34
  buffered = io.BytesIO()
35
  pil_mask.save(buffered, format="PNG")
36
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from PIL import Image
7
  import io
8
  import base64
9
+ from torchvision import transforms
10
 
11
  # Препроцессинг: аналог валидации
12
  preprocess_transform = A.Compose([
 
24
  binary_mask = (pred > threshold).astype(np.float32)
25
  return binary_mask # shape (256, 256)
26
 
27
+ # ────────────────────────────────────────────────
28
+ # Для /predict1 — возвращаем маску 256×256
29
+ # ────────────────────────────────────────────────
30
+
31
+ def resize_mask(mask: np.ndarray, size: int = 256) -> np.ndarray:
32
+ return cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)
33
+
34
 
35
  def mask_to_base64(mask: np.ndarray) -> str:
 
36
  pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
37
  buffered = io.BytesIO()
38
  pil_mask.save(buffered, format="PNG")
39
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
40
+
41
+
42
+ # ────────────────────────────────────────────────
43
+ # Для /predict2
44
+ # ────────────────────────────────────────────────
45
+
46
+ # Новые для классификации
47
+ FRUIT_CLASSES = ['apple', 'banana', 'orange', 'grape', 'strawberry',
48
+ 'tomato', 'pear', 'peach', 'cherry', 'lemon']
49
+
50
+
51
+ def decode_base64_mask(base64_str: str) -> np.ndarray:
52
+ img_data = base64.b64decode(base64_str)
53
+ pil_img = Image.open(io.BytesIO(img_data)).convert('L')
54
+ mask = np.array(pil_img) / 255.0
55
+ return mask.astype(np.float32) # shape ≈ (256, 256)
56
+
57
+
58
+ def letterbox_resize(img: np.ndarray, target_size: int = 256) -> tuple[np.ndarray, float, tuple[int, int]]:
59
+ """
60
+ Resize с сохранением пропорций + padding чёрным
61
+ Возвращает: новое изображение, scale_factor, (pad_top, pad_bottom, pad_left, pad_right)
62
+ """
63
+ h, w = img.shape[:2]
64
+ scale = min(target_size / h, target_size / w)
65
+ new_h, new_w = int(h * scale), int(w * scale)
66
+
67
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
68
+
69
+ pad_h = target_size - new_h
70
+ pad_w = target_size - new_w
71
+ top = pad_h // 2
72
+ bottom = pad_h - top
73
+ left = pad_w // 2
74
+ right = pad_w - left
75
+
76
+ padded = cv2.copyMakeBorder(
77
+ resized, top, bottom, left, right,
78
+ cv2.BORDER_CONSTANT, value=(0, 0, 0)
79
+ )
80
+
81
+ return padded, scale, (top, bottom, left, right)
82
+
83
+
84
+ def apply_mask_and_crop_letterbox(
85
+ orig_img: np.ndarray, # оригинал любой размер
86
+ mask_256: np.ndarray # маска 256×256 [0..1]
87
+ ) -> np.ndarray:
88
+ """
89
+ 1. Делаем letterbox-версию оригинала 256×256
90
+ 2. Применяем маску
91
+ 3. Находим bbox
92
+ 4. Вырезаем + margin
93
+ 5. Ресайзим до 100×100
94
+ """
95
+ letterbox_img, scale, paddings = letterbox_resize(orig_img, 256)
96
+ top, bottom, left, right = paddings
97
+
98
+ # Маска уже 256×256 — применяем напрямую
99
+ masked = letterbox_img.copy()
100
+ masked[mask_256 < 0.5] = 0
101
+
102
+ # Находим контуры / bbox
103
+ mask_bin = (mask_256 > 0.5).astype(np.uint8) * 255
104
+ contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
105
+
106
+ if not contours:
107
+ return np.zeros((100, 100, 3), dtype=np.uint8)
108
+
109
+ cnt = max(contours, key=cv2.contourArea)
110
+ x, y, bw, bh = cv2.boundingRect(cnt)
111
+
112
+ # margin ~10%
113
+ margin = int(max(bw, bh) * 0.12)
114
+ x1 = max(0, x - margin)
115
+ y1 = max(0, y - margin)
116
+ x2 = min(256, x + bw + margin)
117
+ y2 = min(256, y + bh + margin)
118
+
119
+ cropped = masked[y1:y2, x1:x2]
120
+
121
+ # Финальный ресайз до 100×100 для классификатора
122
+ final = cv2.resize(cropped, (100, 100), interpolation=cv2.INTER_AREA)
123
+
124
+ return final
125
+
126
+
127
+ def preprocess_for_classifier(img_100: np.ndarray) -> torch.Tensor:
128
+ transform = transforms.Compose([
129
+ transforms.ToPILImage(),
130
+ transforms.ToTensor(),
131
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132
+ ])
133
+ return transform(img_100)