Bread45879 commited on
Commit
e771935
·
verified ·
1 Parent(s): 55a5fde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -40
app.py CHANGED
@@ -5,69 +5,87 @@ from PIL import Image
5
  import torch
6
  import io
7
 
8
- # 토마토 잎 병해 전용 모델
9
  MODEL_ID = "wellCh4n/tomato-leaf-disease-classification-resnet50"
10
 
11
  app = FastAPI(title="SmartFarm Tomato Disease API")
12
 
13
- # --- 모델 & 전처리기 로드 (서버 시작 시 1번만) ---
14
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
15
  model = ResNetForImageClassification.from_pretrained(MODEL_ID)
16
  model.eval() # 추론 모드
17
 
18
- @torch.no_grad()
19
- def infer_image(img_bytes: bytes):
20
- """
21
- 이미지 바이트를 받아서
22
- - 원본
23
- - 좌우 반전본
24
- 장을 모델에 넣은 뒤,
25
- 평균 logits로 softmax → top5 {label, score} 리스트 반환
26
- """
27
- base_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
 
 
28
 
29
- # 1) TTA용 이미지들
30
- tta_images = [
31
- base_image,
32
- base_image.transpose(Image.FLIP_LEFT_RIGHT),
33
- ]
34
-
35
- logits_sum = None
36
 
37
- for img in tta_images:
38
- inputs = processor(images=img, return_tensors="pt")
39
- outputs = model(**inputs)
40
- if logits_sum is None:
41
- logits_sum = outputs.logits
42
- else:
43
- logits_sum = logits_sum + outputs.logits
44
 
45
- # 2) 평균 logits에 softmax
46
- logits_mean = logits_sum / len(tta_images)
47
- probs = torch.nn.functional.softmax(logits_mean, dim=-1)[0]
 
 
 
 
 
 
 
48
 
49
- # 3) top5 뽑기
50
- values, indices = probs.topk(5)
51
  values = values.tolist()
52
  indices = indices.tolist()
53
 
54
  id2label = model.config.id2label
55
- result = []
56
  for score, idx in zip(values, indices):
57
- label = id2label.get(str(idx), f"Unknown_{idx}")
58
- result.append({
59
  "label": label,
60
- "score": float(score), # 0~1 사이 값
61
  })
62
- return result
63
 
64
 
65
  @app.post("/predict")
66
  async def predict(file: UploadFile = File(...)):
67
  """
68
  PHP에서 보내는 이미지 파일 하나를 받아서
69
- top5 결과를 그대로 JSON으로 반환
 
 
 
 
 
70
  """
71
- img_bytes = await file.read()
72
- result = infer_image(img_bytes)
73
- return JSONResponse(content=result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import io
7
 
8
+ # 토마토 잎 병해 전용 모델
9
  MODEL_ID = "wellCh4n/tomato-leaf-disease-classification-resnet50"
10
 
11
  app = FastAPI(title="SmartFarm Tomato Disease API")
12
 
13
+ # --- 모델 & 전처리기 로딩 (서버 시작 시 1번만) ---
14
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
15
  model = ResNetForImageClassification.from_pretrained(MODEL_ID)
16
  model.eval() # 추론 모드
17
 
18
+ # --- id2label 강제 오버라이드 (Unknown_* 방지용) ---
19
+ # 모델 config 안의 id2label이 이상하면 우리가 직접 지정한다.
20
+ custom_id2label = {
21
+ 0: "Tomato_healthy",
22
+ 1: "Tomato_Bacterial_spot",
23
+ 2: "Tomato_Early_blight",
24
+ 3: "Tomato_Late_blight",
25
+ 4: "Tomato_Leaf_Mold",
26
+ 5: "Tomato_Septoria_leaf_spot",
27
+ 6: "Tomato_Spider_mites_Two_spotted_spider_mite",
28
+ 7: "Tomato_Target_Spot",
29
+ 8: "Tomato_Tomato_Yellow_Leaf_Curl_Virus",
30
+ 9: "Tomato_Tomato_mosaic_virus",
31
+ }
32
 
33
+ # 모델 config 에도 반영 (혹시 내부에서 참조할 수도 있으니까)
34
+ model.config.id2label = {int(k): v for k, v in custom_id2label.items()}
35
+ model.config.label2id = {v: int(k) for k, v in custom_id2label.items()}
 
 
 
 
36
 
 
 
 
 
 
 
 
37
 
38
+ @torch.no_grad()
39
+ def infer_image(img_bytes: bytes, topk: int = 5):
40
+ """
41
+ 이미지 바이트 -> topk [{label, score}, ...] 반환
42
+ score는 0.0~1.0 사이 float
43
+ """
44
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
45
+ inputs = processor(images=image, return_tensors="pt")
46
+ outputs = model(**inputs)
47
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
48
 
49
+ values, indices = probs.topk(topk)
 
50
  values = values.tolist()
51
  indices = indices.tolist()
52
 
53
  id2label = model.config.id2label
54
+ results = []
55
  for score, idx in zip(values, indices):
56
+ label = id2label.get(int(idx), f"Unknown_{idx}")
57
+ results.append({
58
  "label": label,
59
+ "score": float(score),
60
  })
61
+ return results
62
 
63
 
64
  @app.post("/predict")
65
  async def predict(file: UploadFile = File(...)):
66
  """
67
  PHP에서 보내는 이미지 파일 하나를 받아서
68
+ HF Inference API와 비슷한 형식으로 결과 반환:
69
+ [
70
+ {"label": "...", "score": 0.87},
71
+ {"label": "...", "score": 0.05},
72
+ ...
73
+ ]
74
  """
75
+ try:
76
+ img_bytes = await file.read()
77
+ if not img_bytes:
78
+ return JSONResponse(
79
+ {"error": True, "message": "Empty file"},
80
+ status_code=400,
81
+ )
82
+
83
+ raw = infer_image(img_bytes, topk=5)
84
+ return JSONResponse(raw, status_code=200)
85
+
86
+ except Exception as e:
87
+ # 에러 나면 PHP에서 메시지 확인하기 쉽도록 문자열로 내려줌
88
+ return JSONResponse(
89
+ {"error": True, "message": str(e)},
90
+ status_code=500,
91
+ )