ivanm151 commited on
Commit
94c643d
·
1 Parent(s): 0d636fb

one endpoint

Browse files
Files changed (1) hide show
  1. app.py +82 -98
app.py CHANGED
@@ -1,13 +1,10 @@
1
- from fastapi import FastAPI, UploadFile, File, Form
2
  import torch
3
  from models import load_model1, load_model2, load_model3
4
  from utils import (
5
- # Для /predict1
6
  preprocess_image,
7
  postprocess_mask,
8
  mask_to_base64,
9
- # Для /predict2
10
- decode_base64_mask,
11
  apply_mask_and_crop_letterbox,
12
  preprocess_for_classifier,
13
  FRUIT_CLASSES,
@@ -20,119 +17,106 @@ import io
20
 
21
  app = FastAPI()
22
 
23
- # Загрузка моделей один раз при старте
24
- model1 = load_model1() # сегментация → weights/model1.pth
25
- model2 = load_model2() # классификатор → weights/model2.pth
26
- model3 = load_model3()
27
 
28
  DEVICE = torch.device('cpu')
29
 
 
 
 
30
 
31
  @app.get("/")
32
  def greet_json():
33
  return {"Hello": "World!"}
34
 
35
 
36
- @app.post("/predict1")
37
- async def predict1(file: UploadFile = File(...)):
 
 
 
38
  """
39
- Сегментация фрукта → возвращает маску 256×256 в base64
 
 
 
40
  """
 
41
  content = await file.read()
42
  image = Image.open(io.BytesIO(content)).convert('RGB')
43
- image_np = np.array(image)
44
-
45
- input_tensor = preprocess_image(image_np).unsqueeze(0).to(DEVICE)
46
 
 
 
47
  with torch.no_grad():
48
  logits = model1(input_tensor)
49
-
50
- pred_mask = postprocess_mask(logits) # shape (256, 256), float [0,1]
51
-
52
- # Возвращаем только одну маску — 256×256
53
- return {
54
- "mask_256_base64": mask_to_base64(pred_mask)
55
- }
56
-
57
-
58
- @app.post("/predict2")
59
- async def predict2(
60
- file: UploadFile = File(...), # оригинальное изображение (любого размера)
61
- mask_256_base64: str = Form(...) # маска 256×256 от /predict1
62
- ):
63
- """
64
- Классификация фрукта:
65
- - ресайз оригинала → letterbox 256×256
66
- - применение маски
67
- - crop по bounding box + margin
68
- - ресайз результата до 100×100
69
- - inference MobileNetV2
70
- """
71
- # 1. Оригинальное изображение
72
- content = await file.read()
73
- original_pil = Image.open(io.BytesIO(content)).convert('RGB')
74
- original_np = np.array(original_pil)
75
-
76
- # 2. Декодируем маску 256×256
77
- mask_256 = decode_base64_mask(mask_256_base64)
78
-
79
- # 3. Letterbox + маска + crop + resize до 100×100
80
- cropped_100 = apply_mask_and_crop_letterbox(original_np, mask_256, margin_ratio=0.02,
81
- target_size=100,
82
- bg_color=(255, 255, 255))
83
-
84
- # 4. Препроцессинг для классификатора
85
- input_tensor = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
86
-
87
- # 5. Инференс
88
- with torch.no_grad():
89
- logits = model2(input_tensor)
90
- probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
91
-
92
- pred_idx = int(np.argmax(probs))
93
- confidence = float(probs[pred_idx])
94
-
95
- return {
96
- "predicted_fruit": FRUIT_CLASSES[pred_idx],
97
- "confidence": round(confidence, 4),
98
- "class_index": pred_idx
99
- }
100
-
101
-
102
- @app.post("/predict3")
103
- async def predict3(
104
- file: UploadFile = File(...), # оригинальное изображение
105
- mask_256_base64: str = Form(...) # та же маска 256×256 от /predict1
106
- ):
107
- # Оригинал
108
- content = await file.read()
109
- original_pil = Image.open(io.BytesIO(content)).convert('RGB')
110
- original_np = np.array(original_pil)
111
-
112
- # Маска
113
- mask_256 = decode_base64_mask(mask_256_base64)
114
-
115
- # Вырезаем и готовим 224×224
116
- cropped_224 = apply_mask_and_crop_letterbox(
117
- original_np,
118
  mask_256,
119
- margin_ratio=0.05,
120
- target_size=224,
121
  bg_color=(255, 255, 255)
122
  )
123
 
124
- # Preprocess + inference
125
- input_tensor = preprocess_for_freshness(cropped_224).unsqueeze(0).to(DEVICE)
126
-
127
  with torch.no_grad():
128
- logits = model3(input_tensor)
129
- probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
130
-
131
- pred_idx = int(np.argmax(probs))
132
- confidence = float(probs[pred_idx])
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- return {
135
- "predicted_class": FRESHNESS_CLASSES[pred_idx],
136
- "confidence": round(confidence, 4),
137
- "class_index": pred_idx
138
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Query
2
  import torch
3
  from models import load_model1, load_model2, load_model3
4
  from utils import (
 
5
  preprocess_image,
6
  postprocess_mask,
7
  mask_to_base64,
 
 
8
  apply_mask_and_crop_letterbox,
9
  preprocess_for_classifier,
10
  FRUIT_CLASSES,
 
17
 
18
  app = FastAPI()
19
 
20
+ # Глобальная загрузка моделей
21
+ model1 = load_model1() # segmentation
22
+ model2 = load_model2() # fruit type (10 классов)
23
+ model3 = load_model3() # freshness (6 классов)
24
 
25
  DEVICE = torch.device('cpu')
26
 
27
+ # Классы, для которых делаем свежесть
28
+ FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange'}
29
+
30
 
31
  @app.get("/")
32
  def greet_json():
33
  return {"Hello": "World!"}
34
 
35
 
36
+ @app.post("/predict_full")
37
+ async def predict_full(
38
+ file: UploadFile = File(...),
39
+ return_mask: bool = Query(default=False, description="Вернуть base64 маску сегментации?")
40
+ ):
41
  """
42
+ Полный пайплайн:
43
+ 1. Сегментация → маска
44
+ 2. Если фрукт найден → классификация сорта
45
+ 3. Если сорт в ['apple', 'banana', 'orange'] → классификация свежести
46
  """
47
+ # 1. Чтение изображения
48
  content = await file.read()
49
  image = Image.open(io.BytesIO(content)).convert('RGB')
50
+ orig_np = np.array(image)
 
 
51
 
52
+ # 2. Сегментация
53
+ input_tensor = preprocess_image(orig_np).unsqueeze(0).to(DEVICE)
54
  with torch.no_grad():
55
  logits = model1(input_tensor)
56
+ mask_256 = postprocess_mask(logits) # (256, 256) float [0,1]
57
+
58
+ # Проверка: есть ли фрукт? (площадь > 5%)
59
+ fruit_area_ratio = np.mean(mask_256 > 0.5)
60
+ if fruit_area_ratio < 0.05:
61
+ return {
62
+ "status": "no_fruit_detected",
63
+ "fruit_area_ratio": round(fruit_area_ratio, 4),
64
+ "fruit": None,
65
+ "fruit_confidence": None,
66
+ "freshness": None,
67
+ "freshness_confidence": None,
68
+ "mask_256_base64": mask_to_base64(mask_256) if return_mask else None
69
+ }
70
+
71
+ # 3. Обрезание под модель сорта (100×100)
72
+ cropped_100 = apply_mask_and_crop_letterbox(
73
+ orig_np,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  mask_256,
75
+ margin_ratio=0.02,
76
+ target_size=100,
77
  bg_color=(255, 255, 255)
78
  )
79
 
80
+ # 4. Классификация сорта
81
+ input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
 
82
  with torch.no_grad():
83
+ logits2 = model2(input_tensor2)
84
+ probs2 = torch.softmax(logits2, dim=1).squeeze().cpu().numpy()
85
+
86
+ fruit_idx = int(np.argmax(probs2))
87
+ fruit_name = FRUIT_CLASSES[fruit_idx]
88
+ fruit_conf = float(probs2[fruit_idx])
89
+
90
+ result = {
91
+ "status": "success",
92
+ "fruit_area_ratio": round(fruit_area_ratio, 4),
93
+ "fruit": fruit_name,
94
+ "fruit_confidence": round(fruit_conf, 4),
95
+ "freshness": None,
96
+ "freshness_confidence": None,
97
+ "mask_256_base64": mask_to_base64(mask_256) if return_mask else None
98
+ }
99
 
100
+ # 5. Если фрукт подходит — делаем свежесть
101
+ if fruit_name in FRESHNESS_ELIGIBLE:
102
+ cropped_224 = apply_mask_and_crop_letterbox(
103
+ orig_np,
104
+ mask_256,
105
+ margin_ratio=0.05,
106
+ target_size=224,
107
+ bg_color=(255, 255, 255)
108
+ )
109
+
110
+ input_tensor3 = preprocess_for_freshness(cropped_224).unsqueeze(0).to(DEVICE)
111
+ with torch.no_grad():
112
+ logits3 = model3(input_tensor3)
113
+ probs3 = torch.softmax(logits3, dim=1).squeeze().cpu().numpy()
114
+
115
+ fresh_idx = int(np.argmax(probs3))
116
+ fresh_name = FRESHNESS_CLASSES[fresh_idx]
117
+ fresh_conf = float(probs3[fresh_idx])
118
+
119
+ result["freshness"] = fresh_name
120
+ result["freshness_confidence"] = round(fresh_conf, 4)
121
+
122
+ return result