Thompson001 commited on
Commit
ea814d3
ยท
verified ยท
1 Parent(s): 6b8d8ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -73
app.py CHANGED
@@ -1,106 +1,109 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- import os
5
 
6
  # ------------------------------
7
- # 1) ๋ชจ๋ธ ํŒŒ์ผ ๊ฒฝ๋กœ ์ง€์ •
8
  # ------------------------------
9
  MODEL_PATH = "crack_detection.h5"
10
 
11
- IS_TF = MODEL_PATH.endswith(".h5") or MODEL_PATH.endswith(".keras")
12
-
13
  # ------------------------------
14
  # 2) TensorFlow ๋ชจ๋ธ ๋กœ๋“œ
15
  # ------------------------------
16
- if IS_TF:
17
- import tensorflow as tf
18
- model = tf.keras.models.load_model(MODEL_PATH)
19
- print("๐Ÿ”ฅ Loaded TensorFlow crack classifier")
20
-
21
- # ------------------------------
22
- # 3) PyTorch ๋ชจ๋ธ ๋กœ๋“œ
23
- # ------------------------------
24
- else:
25
- import torch
26
- from torch import nn
27
-
28
- class CNN(nn.Module):
29
- # ๋„ค๊ฐ€ ๊ฐ€์ง„ ๋ชจ๋ธ ๊ตฌ์กฐ ๋งž๊ฒŒ ์กฐ์ • ํ•„์š”
30
- def __init__(self):
31
- super().__init__()
32
- self.net = nn.Sequential(
33
- nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
34
- nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
35
- nn.Flatten(),
36
- nn.Linear(32 * 56 * 56, 2) # ์ด๋ฏธ์ง€ ํฌ๊ธฐ ๋งž๊ฒŒ ์กฐ์ • ํ•„์š”
37
- )
38
- def forward(self, x):
39
- return self.net(x)
40
 
41
- model = CNN()
42
- model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
43
- model.eval()
44
- print("๐Ÿ”ฅ Loaded PyTorch crack classifier")
 
45
 
46
 
47
  # ------------------------------
48
- # 4) ์ „์ฒด ์˜ˆ์ธก ํ•จ์ˆ˜
 
49
  # ------------------------------
50
  def predict(img: Image.Image):
51
- # ์ž…๋ ฅ ๋ณ€ํ™˜
52
- img_resized = img.resize((224, 224))
53
- arr = np.array(img_resized) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- if IS_TF:
56
- # TensorFlow
57
- X = arr.reshape(1, 224, 224, 3)
58
- probs = model.predict(X)[0] # [p_normal, p_crack]
 
 
59
 
60
- else:
61
- # PyTorch
62
- import torch
63
- X = (
64
- torch.tensor(arr)
65
- .permute(2, 0, 1)
66
- .unsqueeze(0)
67
- .float()
68
- )
69
- probs = torch.softmax(model(X), dim=1).detach().numpy()[0]
70
 
71
- p_normal = float(probs[0])
72
- p_crack = float(probs[1])
 
 
 
 
 
 
 
 
 
73
 
74
- if p_crack > p_normal:
75
- label = "crack"
76
- conf = p_crack
77
- else:
78
- label = "normal"
79
- conf = p_normal
 
 
80
 
81
- # ------------------------------
82
- # ํ”„๋ก ํŠธ๊ฐ€ ์š”๊ตฌํ•˜๋Š” JSON ๊ตฌ์กฐ
83
- # ------------------------------
84
- return {
85
- "data": [
86
- {
87
- "label": label,
88
- "confidence": conf
89
- }
90
- ]
91
- }
 
 
 
92
 
93
 
94
  # ------------------------------
95
- # 5) Gradio API Interface
96
  # ------------------------------
97
  demo = gr.Interface(
98
  fn=predict,
99
- inputs=gr.Image(type="pil"),
100
  outputs=gr.JSON(label="Detection Result"),
101
- title="Crack Detection Classifier",
102
  description="์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜๋ฉด ๊ท ์—ด/์ •์ƒ ์—ฌ๋ถ€์™€ ํ™•๋ฅ (%)์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.",
103
- flagging_mode="never"
104
  )
105
 
106
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
+ import tensorflow as tf
5
 
6
  # ------------------------------
7
+ # 1) ๋ชจ๋ธ ํŒŒ์ผ ๊ฒฝ๋กœ
8
  # ------------------------------
9
  MODEL_PATH = "crack_detection.h5"
10
 
 
 
11
  # ------------------------------
12
  # 2) TensorFlow ๋ชจ๋ธ ๋กœ๋“œ
13
  # ------------------------------
14
+ model = tf.keras.models.load_model(MODEL_PATH)
15
+ print("๐Ÿ”ฅ Loaded TensorFlow crack classifier")
16
+ print(" Input shape :", model.input_shape)
17
+ print(" Output shape:", model.output_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # (None, H, W, C) ํ˜•ํƒœ๋ผ๊ณ  ๊ฐ€์ •
20
+ input_shape = model.input_shape
21
+ if len(input_shape) != 4:
22
+ raise ValueError(f"์˜ˆ์ƒ์น˜ ๋ชปํ•œ input_shape: {input_shape}")
23
+ _, H, W, C = input_shape
24
 
25
 
26
  # ------------------------------
27
+ # 3) ์˜ˆ์ธก ํ•จ์ˆ˜
28
+ # ํ•ญ์ƒ JSON์„ ๋ฆฌํ„ดํ•˜๋„๋ก try/except
29
  # ------------------------------
30
  def predict(img: Image.Image):
31
+ try:
32
+ # 1) ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
33
+ img = img.convert("RGB")
34
+ img_resized = img.resize((W, H))
35
+ arr = np.array(img_resized).astype("float32") / 255.0
36
+ X = np.expand_dims(arr, axis=0) # (1, H, W, C)
37
+
38
+ # 2) ๋ชจ๋ธ ์ถ”๋ก 
39
+ raw = model.predict(X)[0]
40
+ probs = np.array(raw, dtype="float32").flatten()
41
+
42
+ # 3) ์ถœ๋ ฅ ํ•ด์„
43
+ # - ๊ธธ์ด 1 : sigmoid โ†’ p_crack
44
+ # - ๊ธธ์ด 2+ : [p_normal, p_crack] ๊ฐ€์ •
45
+ if probs.shape[0] == 1:
46
+ p_crack = float(probs[0])
47
+ p_normal = 1.0 - p_crack
48
 
49
+ if p_crack >= 0.5:
50
+ label = "crack"
51
+ conf = p_crack
52
+ else:
53
+ label = "normal"
54
+ conf = p_normal
55
 
56
+ elif probs.shape[0] >= 2:
57
+ p_normal = float(probs[0])
58
+ p_crack = float(probs[1])
 
 
 
 
 
 
 
59
 
60
+ if p_crack >= p_normal:
61
+ label = "crack"
62
+ conf = p_crack
63
+ else:
64
+ label = "normal"
65
+ conf = p_normal
66
+ else:
67
+ # ๋งค์šฐ ํŠน์ดํ•œ ์ผ€์ด์Šค โ†’ ๊ทธ๋ƒฅ argmax
68
+ idx = int(np.argmax(probs))
69
+ label = f"class_{idx}"
70
+ conf = float(probs[idx])
71
 
72
+ return {
73
+ "data": [
74
+ {
75
+ "label": label,
76
+ "confidence": float(conf),
77
+ }
78
+ ]
79
+ }
80
 
81
+ except Exception as e:
82
+ # โ— ์—ฌ๊ธฐ์„œ ์˜ˆ์™ธ๋ฅผ ๋ชจ๋‘ ์žก์•„์„œ JSON์œผ๋กœ ๋‚ด๋ ค์คŒ
83
+ # ์ด๋ ‡๊ฒŒ ํ•ด์•ผ HF Space๊ฐ€ 500 ์•ˆ ๋˜์ง€๊ณ ,
84
+ # ํ”„๋ก ํŠธ์—์„œ Raw Response/JSON Payload๋ฅผ ๋ณผ ์ˆ˜ ์žˆ์Œ.
85
+ print("โŒ Error in predict():", repr(e))
86
+ return {
87
+ "data": [
88
+ {
89
+ "label": "error",
90
+ "confidence": 0.0,
91
+ "message": str(e),
92
+ }
93
+ ]
94
+ }
95
 
96
 
97
  # ------------------------------
98
+ # 4) Gradio Interface
99
  # ------------------------------
100
  demo = gr.Interface(
101
  fn=predict,
102
+ inputs=gr.Image(type="pil", label="Input image"),
103
  outputs=gr.JSON(label="Detection Result"),
104
+ title="Crack Detection Classifier (Keras .h5)",
105
  description="์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜๋ฉด ๊ท ์—ด/์ •์ƒ ์—ฌ๋ถ€์™€ ํ™•๋ฅ (%)์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.",
106
+ flagging_mode="never",
107
  )
108
 
109
  if __name__ == "__main__":