ivanm151 commited on
Commit
9672426
·
1 Parent(s): acddc4b

3rd endpoint

Browse files
Files changed (3) hide show
  1. app.py +44 -2
  2. models.py +17 -2
  3. utils.py +62 -1
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, UploadFile, File, Form
2
  import torch
3
- from models import load_model1, load_model2
4
  from utils import (
5
  # Для /predict1
6
  preprocess_image,
@@ -10,7 +10,10 @@ from utils import (
10
  decode_base64_mask,
11
  apply_mask_and_crop_letterbox,
12
  preprocess_for_classifier,
13
- FRUIT_CLASSES
 
 
 
14
  )
15
  import numpy as np
16
  from PIL import Image
@@ -21,6 +24,7 @@ app = FastAPI()
21
  # Загрузка моделей один раз при старте
22
  model1 = load_model1() # сегментация → weights/model1.pth
23
  model2 = load_model2() # классификатор → weights/model2.pth
 
24
 
25
  DEVICE = torch.device('cpu')
26
 
@@ -91,4 +95,42 @@ async def predict2(
91
  "predicted_fruit": FRUIT_CLASSES[pred_idx],
92
  "confidence": round(confidence, 4),
93
  "class_index": pred_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  }
 
1
  from fastapi import FastAPI, UploadFile, File, Form
2
  import torch
3
+ from models import load_model1, load_model2, load_model3
4
  from utils import (
5
  # Для /predict1
6
  preprocess_image,
 
10
  decode_base64_mask,
11
  apply_mask_and_crop_letterbox,
12
  preprocess_for_classifier,
13
+ FRUIT_CLASSES,
14
+ apply_mask_and_crop_letterbox_224,
15
+ preprocess_for_freshness,
16
+ FRESHNESS_CLASSES
17
  )
18
  import numpy as np
19
  from PIL import Image
 
24
  # Загрузка моделей один раз при старте
25
  model1 = load_model1() # сегментация → weights/model1.pth
26
  model2 = load_model2() # классификатор → weights/model2.pth
27
+ model3 = load_model3()
28
 
29
  DEVICE = torch.device('cpu')
30
 
 
95
  "predicted_fruit": FRUIT_CLASSES[pred_idx],
96
  "confidence": round(confidence, 4),
97
  "class_index": pred_idx
98
+ }
99
+
100
+
101
+ @app.post("/predict3")
102
+ async def predict3(
103
+ file: UploadFile = File(...), # оригинальное изображение
104
+ mask_256_base64: str = Form(...) # та же маска 256×256 от /predict1
105
+ ):
106
+ # Оригинал
107
+ content = await file.read()
108
+ original_pil = Image.open(io.BytesIO(content)).convert('RGB')
109
+ original_np = np.array(original_pil)
110
+
111
+ # Маска
112
+ mask_256 = decode_base64_mask(mask_256_base64)
113
+
114
+ # Вырезаем и готовим 224×224
115
+ cropped_224 = apply_mask_and_crop_letterbox_224(
116
+ original_np,
117
+ mask_256,
118
+ margin_ratio=0.05, # подбери под свои тесты
119
+ bg_color=(255, 255, 255)
120
+ )
121
+
122
+ # Preprocess + inference
123
+ input_tensor = preprocess_for_freshness(cropped_224).unsqueeze(0).to(DEVICE)
124
+
125
+ with torch.no_grad():
126
+ logits = model3(input_tensor)
127
+ probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
128
+
129
+ pred_idx = int(np.argmax(probs))
130
+ confidence = float(probs[pred_idx])
131
+
132
+ return {
133
+ "predicted_class": FRESHNESS_CLASSES[pred_idx],
134
+ "confidence": round(confidence, 4),
135
+ "class_index": pred_idx
136
  }
models.py CHANGED
@@ -5,7 +5,8 @@ import torch.nn as nn
5
  DEVICE = torch.device('cpu')
6
 
7
  model1 = None # сегментация (как раньше)
8
- model2 = None # классификатор
 
9
 
10
  def load_model1(weights_path='weights/seg.pth'):
11
  global model1
@@ -36,4 +37,18 @@ def load_model2(weights_path='weights/class1.pth'):
36
  model2.load_state_dict(state_dict)
37
  model2.to(DEVICE)
38
  model2.eval()
39
- return model2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  DEVICE = torch.device('cpu')
6
 
7
  model1 = None # сегментация (как раньше)
8
+ model2 = None # классификатор 1
9
+ model3 = None # классификатор 2
10
 
11
  def load_model1(weights_path='weights/seg.pth'):
12
  global model1
 
37
  model2.load_state_dict(state_dict)
38
  model2.to(DEVICE)
39
  model2.eval()
40
+ return model2
41
+
42
+
43
+ def load_model3(weights_path='weights/class2.pth'):
44
+ global model3
45
+ if model3 is None:
46
+ model3 = models.mobilenet_v2(pretrained=False)
47
+ for param in model3.features.parameters():
48
+ param.requires_grad = False
49
+ model3.classifier[1] = nn.Linear(model3.classifier[1].in_features, 6)
50
+ state_dict = torch.load(weights_path, map_location=DEVICE)
51
+ model3.load_state_dict(state_dict)
52
+ model3.to(DEVICE)
53
+ model3.eval()
54
+ return model3
utils.py CHANGED
@@ -130,4 +130,65 @@ def preprocess_for_classifier(img_100: np.ndarray) -> torch.Tensor:
130
  transforms.ToTensor(),
131
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132
  ])
133
- return transform(img_100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  transforms.ToTensor(),
131
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132
  ])
133
+ return transform(img_100)
134
+
135
+
136
+ # ... весь предыдущий код остаётся ...
137
+
138
+ # Новые константы для модели свежести
139
+ FRESHNESS_CLASSES = [
140
+ 'freshapples', 'freshbanana', 'freshoranges',
141
+ 'rottenapples', 'rottenbanana', 'rottenoranges'
142
+ ]
143
+
144
+ def preprocess_for_freshness(img_224: np.ndarray) -> torch.Tensor:
145
+ """ Трансформации, аналогичные test_transforms из обучения """
146
+ transform = transforms.Compose([
147
+ transforms.ToPILImage(),
148
+ transforms.ToTensor(),
149
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
150
+ ])
151
+ return transform(img_224)
152
+
153
+
154
+ def apply_mask_and_crop_letterbox_224(
155
+ orig_img: np.ndarray,
156
+ mask_256: np.ndarray,
157
+ margin_ratio: float = 0.05, # можно подкрутить
158
+ bg_color: tuple = (255, 255, 255) # белый фон — важно!
159
+ ) -> np.ndarray:
160
+ """
161
+ Аналог apply_mask_and_crop_letterbox, но для 224×224
162
+ """
163
+ # Letterbox до 224×224
164
+ letterbox_img, scale, paddings = letterbox_resize(orig_img, target_size=224)
165
+ top, bottom, left, right = paddings
166
+
167
+ # Применяем маску (маска 256→ресайзим до 224)
168
+ mask_resized = cv2.resize(mask_256, (224, 224), interpolation=cv2.INTER_NEAREST)
169
+
170
+ masked = letterbox_img.copy()
171
+ masked[mask_resized < 0.5] = bg_color # белый фон
172
+
173
+ # Контуры
174
+ mask_bin = (mask_resized > 0.5).astype(np.uint8) * 255
175
+ contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
176
+
177
+ if not contours:
178
+ return np.full((224, 224, 3), bg_color, dtype=np.uint8)
179
+
180
+ cnt = max(contours, key=cv2.contourArea)
181
+ x, y, bw, bh = cv2.boundingRect(cnt)
182
+
183
+ margin = int(max(bw, bh) * margin_ratio)
184
+ x1 = max(0, x - margin)
185
+ y1 = max(0, y - margin)
186
+ x2 = min(224, x + bw + margin)
187
+ y2 = min(224, y + bh + margin)
188
+
189
+ cropped = masked[y1:y2, x1:x2]
190
+
191
+ # Финальный resize до 224×224 (если обрезали меньше)
192
+ final = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
193
+
194
+ return final