Spaces:
Sleeping
Sleeping
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, ResNetForImageClassification
app = FastAPI(title="SmartFarm Tomato Disease API")

# Model dedicated to tomato leaf disease classification.
MODEL_ID = "wellCh4n/tomato-leaf-disease-classification-resnet50"

# --- Load the preprocessor and the model a single time, at server startup ---
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() returns the module itself, so loading and switching the
# model into inference mode can be done in one expression.
model = ResNetForImageClassification.from_pretrained(MODEL_ID).eval()
# --- Force-override id2label (prevents "Unknown_*" labels in responses) ---
# The label names stored in the model's own config are not trusted, so the
# output-index -> class-name mapping is specified explicitly here.
# NOTE(review): this assumes output index i of the checkpoint corresponds to
# the class named below (healthy first). That order cannot be verified from
# this file -- confirm it against the model card / training label order.
custom_id2label = {
    0: "Tomato_healthy",
    1: "Tomato_Bacterial_spot",
    2: "Tomato_Early_blight",
    3: "Tomato_Late_blight",
    4: "Tomato_Leaf_Mold",
    5: "Tomato_Septoria_leaf_spot",
    6: "Tomato_Spider_mites_Two_spotted_spider_mite",
    7: "Tomato_Target_Spot",
    8: "Tomato_Tomato_Yellow_Leaf_Curl_Virus",
    9: "Tomato_Tomato_mosaic_virus",
}

# Sanity-check the override against the checkpoint BEFORE replacing the
# config (num_labels is derived from the config's current id2label). A
# mismatch means some outputs would map to "Unknown_*" or some names would
# never be produced. Warn instead of failing so server startup is unchanged.
_checkpoint_num_labels = model.config.num_labels
if _checkpoint_num_labels != len(custom_id2label):
    logging.getLogger(__name__).warning(
        "Model %s reports %d labels but custom_id2label defines %d; "
        "predicted label names may be wrong.",
        MODEL_ID,
        _checkpoint_num_labels,
        len(custom_id2label),
    )

# Mirror the mapping into the model config as well, in case anything inside
# the library reads it. Keys are already ints, so no conversion is needed.
model.config.id2label = dict(custom_id2label)
model.config.label2id = {name: idx for idx, name in custom_id2label.items()}
def infer_image(img_bytes: bytes, topk: int = 5):
    """Classify an encoded image and return its top-k predicted labels.

    Args:
        img_bytes: Raw bytes of an image file in a format PIL can open.
        topk: Maximum number of predictions to return. Values larger than
            the number of model classes are clamped to the class count;
            values below 0 are treated as 0 (empty result).

    Returns:
        A list of ``{"label": str, "score": float}`` dicts ordered by
        descending score. Each score is a softmax probability between
        0.0 and 1.0. Indices missing from ``model.config.id2label`` are
        reported as ``"Unknown_<idx>"``.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image.
    """
    # convert() returns an independent copy, so the source image (and its
    # underlying buffer) can be closed as soon as the RGB copy exists.
    with Image.open(io.BytesIO(img_bytes)) as source:
        image = source.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Pure inference: disable autograd tracking to save memory and time.
    with torch.inference_mode():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]
        # Tensor.topk raises if k exceeds the number of classes (or is
        # negative), so clamp k into the valid range.
        k = max(0, min(topk, probs.numel()))
        values, indices = probs.topk(k)

    id2label = model.config.id2label
    return [
        {"label": id2label.get(idx, f"Unknown_{idx}"), "score": float(score)}
        for score, idx in zip(values.tolist(), indices.tolist())
    ]
# NOTE(review): the original handler carried no route decorator, so the app
# exposed no endpoint at all. The "/predict" path is assumed from the
# function name -- confirm it matches the URL the PHP client calls.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify one uploaded image and return its top-5 predictions.

    Receives a single image file (sent by the PHP client) and returns the
    result in a format similar to the Hugging Face Inference API::

        [
            {"label": "...", "score": 0.87},
            {"label": "...", "score": 0.05},
            ...
        ]

    On failure the body is ``{"error": True, "message": <str>}`` with status
    400 for an empty or undecodable upload and 500 for any other error.
    """
    try:
        img_bytes = await file.read()
        if not img_bytes:
            return JSONResponse(
                {"error": True, "message": "Empty file"},
                status_code=400,
            )
        # Model inference is synchronous and CPU-bound; run it in a worker
        # thread so it does not block the event loop for other requests.
        raw = await run_in_threadpool(infer_image, img_bytes, topk=5)
        return JSONResponse(raw, status_code=200)
    except UnidentifiedImageError as e:
        # The upload is not a decodable image: a client error, not a fault.
        return JSONResponse(
            {"error": True, "message": str(e)},
            status_code=400,
        )
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log.
        logging.getLogger(__name__).exception("Prediction failed")
        # Return the message as a plain string so the PHP side can show it.
        return JSONResponse(
            {"error": True, "message": str(e)},
            status_code=500,
        )