# app.py — uploaded by Bread45879; commit e771935 ("Update app.py")
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from transformers import AutoImageProcessor, ResNetForImageClassification
from PIL import Image
import torch
import io
# Model dedicated to tomato leaf disease classification
MODEL_ID = "wellCh4n/tomato-leaf-disease-classification-resnet50"
app = FastAPI(title="SmartFarm Tomato Disease API")
# --- Load the model & preprocessor (only once, when the server starts) ---
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = ResNetForImageClassification.from_pretrained(MODEL_ID)
model.eval() # inference mode
# --- Forced id2label override (prevents "Unknown_*" labels in responses) ---
# The labels are assigned here explicitly instead of trusting whatever the
# model config ships with. A label's position in the tuple is its class index.
# NOTE(review): this order is assumed to match the model's output indices —
# confirm against the model card before relying on the names.
custom_id2label = dict(
    enumerate(
        (
            "Tomato_healthy",
            "Tomato_Bacterial_spot",
            "Tomato_Early_blight",
            "Tomato_Late_blight",
            "Tomato_Leaf_Mold",
            "Tomato_Septoria_leaf_spot",
            "Tomato_Spider_mites_Two_spotted_spider_mite",
            "Tomato_Target_Spot",
            "Tomato_Tomato_Yellow_Leaf_Curl_Virus",
            "Tomato_Tomato_mosaic_virus",
        )
    )
)

# Mirror the mapping into the model config as well, in case anything reads
# the labels from there.
model.config.id2label = dict(custom_id2label)
model.config.label2id = {
    name: class_id for class_id, name in custom_id2label.items()
}
@torch.no_grad()
def infer_image(img_bytes: bytes, topk: int = 5):
    """Classify an encoded image and return its most probable labels.

    Args:
        img_bytes: Raw bytes of an encoded image file. The image is decoded
            with Pillow and converted to RGB before preprocessing.
        topk: Maximum number of predictions to return. A value larger than
            the number of classes the model outputs is clamped to that
            number instead of raising.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, highest softmax
        probability first. Each score lies between 0.0 and 1.0. A class
        index missing from ``model.config.id2label`` is reported as
        ``"Unknown_<index>"``.
    """
    # The RGB copy is independent of the source image, so the source can be
    # closed as soon as the conversion is done.
    with Image.open(io.BytesIO(img_bytes)) as source:
        image = source.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    logits = model(**inputs).logits
    # [0] selects the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the size of the dimension, so never ask
    # for more entries than there are classes.
    k = min(topk, probs.shape[-1])
    scores, class_ids = probs.topk(k)

    id2label = model.config.id2label
    return [
        {"label": id2label.get(idx, f"Unknown_{idx}"), "score": float(score)}
        for score, idx in zip(scores.tolist(), class_ids.tolist())
    ]
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify one uploaded image file.

    Accepts a single image file (sent by the PHP client) and returns the
    predictions in a format similar to the HF Inference API:

        [
            {"label": "...", "score": 0.87},
            {"label": "...", "score": 0.05},
            ...
        ]

    Failures are reported as ``{"error": true, "message": "..."}`` with
    status 400 when the upload is empty or cannot be decoded as an image,
    and status 500 for any other error.
    """
    # Imported here so the handler does not depend on changes to the
    # module-level import block.
    from PIL import UnidentifiedImageError

    try:
        img_bytes = await file.read()
        if not img_bytes:
            return JSONResponse(
                {"error": True, "message": "Empty file"},
                status_code=400,
            )
        predictions = infer_image(img_bytes, topk=5)
        return JSONResponse(predictions, status_code=200)
    except UnidentifiedImageError as e:
        # The upload is not a decodable image: a client error, not a server
        # fault, so answer 400 with the same error shape.
        return JSONResponse(
            {"error": True, "message": str(e)},
            status_code=400,
        )
    except Exception as e:
        # On error, send the message back as a string so it is easy to
        # inspect from the PHP side.
        return JSONResponse(
            {"error": True, "message": str(e)},
            status_code=500,
        )