new response
Browse files
app.py
CHANGED
|
@@ -25,7 +25,7 @@ FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange', 'lemon'}
|
|
| 25 |
|
| 26 |
@app.get("/")
|
| 27 |
def greet_json():
|
| 28 |
-
return {"
|
| 29 |
|
| 30 |
@app.post("/predict_full")
|
| 31 |
async def predict_full(
|
|
@@ -61,10 +61,7 @@ async def predict_full(
|
|
| 61 |
fruit_area_ratio = np.mean(mask)
|
| 62 |
if fruit_area_ratio < 0.01:
|
| 63 |
return {
|
| 64 |
-
"
|
| 65 |
-
"fruit_area_ratio": round(fruit_area_ratio, 4),
|
| 66 |
-
"fruit": None,
|
| 67 |
-
"fruit_confidence": None,
|
| 68 |
"freshness": None,
|
| 69 |
"freshness_confidence": None,
|
| 70 |
"cropped_base64": None
|
|
@@ -73,27 +70,34 @@ async def predict_full(
|
|
| 73 |
# Обрезка под 100×100 для сорта
|
| 74 |
cropped_100 = crop_fruit_contour_letterbox(orig_np, mask, out_size=100)
|
| 75 |
input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
|
|
|
|
| 76 |
with torch.no_grad():
|
| 77 |
logits2 = model2(input_tensor2)
|
| 78 |
probs2 = torch.softmax(logits2, dim=1).squeeze().cpu().numpy()
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
result = {
|
| 85 |
-
"
|
| 86 |
-
"fruit_area_ratio": round(fruit_area_ratio, 4),
|
| 87 |
-
"fruit": fruit_name,
|
| 88 |
-
"fruit_confidence": round(fruit_conf, 4),
|
| 89 |
"freshness": None,
|
| 90 |
"freshness_confidence": None,
|
| 91 |
"cropped_base64": None
|
| 92 |
}
|
| 93 |
|
| 94 |
-
# Свежесть, если
|
| 95 |
-
if
|
| 96 |
-
cropped_224 = crop_fruit_contour_letterbox(orig_np, mask, out_size=
|
| 97 |
input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
|
| 98 |
with torch.no_grad():
|
| 99 |
logits3 = model3(input_tensor3)
|
|
@@ -113,6 +117,5 @@ async def predict_full(
|
|
| 113 |
buffered = io.BytesIO()
|
| 114 |
pil_img.save(buffered, format="PNG")
|
| 115 |
result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 116 |
-
result["cropped_size"] = f"{cropped_size}x{cropped_size}"
|
| 117 |
|
| 118 |
return result
|
|
|
|
| 25 |
|
| 26 |
@app.get("/")
|
| 27 |
def greet_json():
|
| 28 |
+
return {"swagger https://ivanm151-fruits.hf.space/docs#"}
|
| 29 |
|
| 30 |
@app.post("/predict_full")
|
| 31 |
async def predict_full(
|
|
|
|
| 61 |
fruit_area_ratio = np.mean(mask)
|
| 62 |
if fruit_area_ratio < 0.01:
|
| 63 |
return {
|
| 64 |
+
"fruit_top3": [],
|
|
|
|
|
|
|
|
|
|
| 65 |
"freshness": None,
|
| 66 |
"freshness_confidence": None,
|
| 67 |
"cropped_base64": None
|
|
|
|
| 70 |
# Обрезка под 100×100 для сорта
|
| 71 |
cropped_100 = crop_fruit_contour_letterbox(orig_np, mask, out_size=100)
|
| 72 |
input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
|
| 73 |
+
|
| 74 |
with torch.no_grad():
|
| 75 |
logits2 = model2(input_tensor2)
|
| 76 |
probs2 = torch.softmax(logits2, dim=1).squeeze().cpu().numpy()
|
| 77 |
|
| 78 |
+
# ТОП-3 фрукта
|
| 79 |
+
top3_indices = np.argsort(probs2)[-3:][::-1] # индексы от самого уверенного
|
| 80 |
+
top3 = [
|
| 81 |
+
{
|
| 82 |
+
"fruit": FRUIT_CLASSES[idx],
|
| 83 |
+
"confidence": round(float(probs2[idx]), 4)
|
| 84 |
+
}
|
| 85 |
+
for idx in top3_indices
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
# Проверяем, есть ли хотя бы один фрукт из FRESHNESS_ELIGIBLE в топ-3
|
| 89 |
+
eligible_in_top3 = any(item["fruit"] in FRESHNESS_ELIGIBLE for item in top3)
|
| 90 |
|
| 91 |
result = {
|
| 92 |
+
"fruit_top3": top3,
|
|
|
|
|
|
|
|
|
|
| 93 |
"freshness": None,
|
| 94 |
"freshness_confidence": None,
|
| 95 |
"cropped_base64": None
|
| 96 |
}
|
| 97 |
|
| 98 |
+
# Свежесть, если есть eligible фрукт в топ-3
|
| 99 |
+
if eligible_in_top3:
|
| 100 |
+
cropped_224 = crop_fruit_contour_letterbox(orig_np, mask, out_size=100)
|
| 101 |
input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
|
| 102 |
with torch.no_grad():
|
| 103 |
logits3 = model3(input_tensor3)
|
|
|
|
| 117 |
buffered = io.BytesIO()
|
| 118 |
pil_img.save(buffered, format="PNG")
|
| 119 |
result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
|
|
| 120 |
|
| 121 |
return result
|