File size: 4,252 Bytes
94c643d 3dc4dee f8ed5ba b1e9f50 1c6f885 2760b2b 1c6f885 9672426 82ab87f 1c6f885 3dc4dee f8ed5ba fbb0759 1c6f885 3dc4dee 611306f 94c643d 3dc4dee 12852aa 3dc4dee 94c643d 4f559e9 f8ed5ba b1e9f50 94c643d 3dc4dee 94c643d 1c6f885 fbb0759 f8ed5ba fbb0759 f8ed5ba b1e9f50 f8ed5ba d3fc523 94c643d 12852aa 94c643d f8ed5ba 94c643d f8ed5ba 2760b2b 94c643d 12852aa 9672426 94c643d 12852aa 94c643d 12852aa 94c643d f8ed5ba 94c643d 9672426 12852aa 777d510 94c643d 4f559e9 f8ed5ba 2760b2b f8ed5ba 94c643d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
from fastapi import FastAPI, UploadFile, File, Query
import torch
import numpy as np
from PIL import Image
import io
import base64
from models import load_sam, load_model2, load_model3
from utils import (
crop_fruit_contour_letterbox,
preprocess_for_classifier,
FRUIT_CLASSES,
FRESHNESS_CLASSES
)
app = FastAPI()
# Load all models once at import time so every request reuses them.
sam_predictor = load_sam() # MobileSAM + SamPredictor
model2 = load_model2()  # fruit-type classifier (indexes into FRUIT_CLASSES)
model3 = load_model3()  # freshness classifier (indexes into FRESHNESS_CLASSES)
DEVICE = torch.device('cpu')
# Only these fruit types get a freshness prediction in /predict_full.
FRESHNESS_ELIGIBLE = {'apple', 'banana', 'orange', 'lemon'}
@app.get("/")
def greet_json():
    """Landing endpoint pointing clients at the Swagger UI.

    Bug fix: the original returned a one-element *set* literal
    ``{"swagger https://..."}``, which FastAPI serializes as a JSON
    array; a key/value mapping is the intended response shape.
    """
    return {"swagger": "https://ivanm151-fruits.hf.space/docs#"}
@app.post("/predict_full")
async def predict_full(
    file: UploadFile = File(...),
    point_x: int = Query(..., description="X-координата точки на фрукте"),
    point_y: int = Query(..., description="Y-координата точки на фрукте"),
    return_cropped: bool = Query(default=True, description="Вернуть обрезанное изображение в base64?"),
    cropped_size: int = Query(224, description="Размер обрезанного изображения (100 или 224)")
):
    """Segment the fruit at (point_x, point_y) and classify it.

    Pipeline: MobileSAM produces a mask from the single point prompt,
    a 100x100 letterboxed crop goes through the fruit-type classifier
    (model2), and — if an eligible fruit appears in the top-3 — a
    224x224 crop goes through the freshness classifier (model3).

    Returns a dict with keys:
        fruit_top3           -- list of {"fruit", "confidence"} (empty if no fruit found)
        freshness            -- freshness class name, or None
        freshness_confidence -- probability rounded to 4 decimals, or None
        cropped_base64       -- base64-encoded PNG crop, or None
    """
    content = await file.read()
    image = Image.open(io.BytesIO(content)).convert('RGB')
    orig_np = np.array(image)
    # Hand the image to the MobileSAM predictor.
    sam_predictor.set_image(orig_np)
    # Prompt: a single foreground point on the fruit.
    input_point = np.array([[point_x, point_y]])
    input_label = np.array([1])  # 1 = foreground
    # Predict the segmentation mask.
    masks, scores, logits = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False
    )
    # Keep the highest-scoring mask (with multimask_output=False there is one).
    best_mask_idx = np.argmax(scores)
    mask = masks[best_mask_idx]  # boolean mask
    # Sanity check: does the mask cover a non-trivial area (i.e. a fruit)?
    fruit_area_ratio = np.mean(mask)
    if fruit_area_ratio < 0.01:
        return {
            "fruit_top3": [],
            "freshness": None,
            "freshness_confidence": None,
            "cropped_base64": None
        }
    # 100x100 letterboxed crop for the fruit-type classifier.
    cropped_100 = crop_fruit_contour_letterbox(orig_np, mask, out_size=100)
    input_tensor2 = preprocess_for_classifier(cropped_100).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits2 = model2(input_tensor2)
        probs2 = torch.softmax(logits2, dim=1).squeeze().cpu().numpy()
    # Top-3 fruit classes, most confident first.
    top3_indices = np.argsort(probs2)[-3:][::-1]
    top3 = [
        {
            "fruit": FRUIT_CLASSES[idx],
            "confidence": round(float(probs2[idx]), 4)
        }
        for idx in top3_indices
    ]
    # Run the freshness model only when an eligible fruit made the top-3.
    eligible_in_top3 = any(item["fruit"] in FRESHNESS_ELIGIBLE for item in top3)
    result = {
        "fruit_top3": top3,
        "freshness": None,
        "freshness_confidence": None,
        "cropped_base64": None
    }
    if eligible_in_top3:
        # BUG FIX: the original passed out_size=100 here even though the
        # variable is named cropped_224 and the endpoint's crop default is
        # 224 — the freshness model is fed a 224x224 crop as intended.
        cropped_224 = crop_fruit_contour_letterbox(orig_np, mask, out_size=224)
        input_tensor3 = preprocess_for_classifier(cropped_224).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits3 = model3(input_tensor3)
            probs3 = torch.softmax(logits3, dim=1).squeeze().cpu().numpy()
        fresh_idx = int(np.argmax(probs3))
        result["freshness"] = FRESHNESS_CLASSES[fresh_idx]
        result["freshness_confidence"] = round(float(probs3[fresh_idx]), 4)
    # Optionally return the cropped image (size chosen by the caller).
    if return_cropped:
        cropped_final = crop_fruit_contour_letterbox(orig_np, mask, out_size=cropped_size)
        pil_img = Image.fromarray(cropped_final)
        buffered = io.BytesIO()
        pil_img.save(buffered, format="PNG")
        result["cropped_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return result