Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
# ------------------------------
# 1) Model file path
# ------------------------------
MODEL_PATH = "crack_detection.h5"
# Keras checkpoints (.h5 / .keras) select the TensorFlow loader; any other
# extension falls back to the PyTorch state-dict loader.
IS_TF = MODEL_PATH.endswith((".h5", ".keras"))
# ------------------------------
# 2) / 3) Load the classifier with the framework matching MODEL_PATH
# ------------------------------
if not IS_TF:
    # PyTorch path: rebuild the architecture, then load the saved weights.
    import torch
    from torch import nn

    class CNN(nn.Module):
        """Small two-stage conv net with a 2-unit output head.

        NOTE: adjust this architecture to match the checkpoint being loaded;
        layer order determines the state-dict keys (net.0, net.3, ...).
        """

        def __init__(self):
            super().__init__()
            # Two conv/ReLU/pool stages halve the spatial size twice, so a
            # 224x224 input reaches the linear layer as 32 x 56 x 56.
            layers = [
                nn.Conv2d(3, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(32 * 56 * 56, 2),  # adjust to match the input image size
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    model = CNN()
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.eval()  # switch to inference mode
    print("๐ฅ Loaded PyTorch crack classifier")
else:
    # TensorFlow / Keras path: load_model restores the saved Keras model.
    import tensorflow as tf

    model = tf.keras.models.load_model(MODEL_PATH)
    print("๐ฅ Loaded TensorFlow crack classifier")
# ------------------------------
# 4) Prediction function
# ------------------------------
def predict(img: Image.Image):
    """Classify a single image as ``"crack"`` or ``"normal"``.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the request carries no image.

    Returns:
        The JSON payload the frontend expects::

            {"data": [{"label": "crack" | "normal", "confidence": float}]}

        ``confidence`` is the probability of the winning class as a fraction
        in [0, 1] (not a percentage). Ties resolve to ``"normal"``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard img.resize would raise an opaque AttributeError;
        # gr.Error surfaces a readable message to the caller instead.
        raise gr.Error("Please upload an image.")

    # Input conversion. Force 3-channel RGB first: RGBA, grayscale or palette
    # images would otherwise produce an array that cannot be reshaped to
    # (224, 224, 3) or fed to a 3-channel conv layer.
    img_resized = img.convert("RGB").resize((224, 224))
    # NOTE(review): [0, 1] scaling assumes the model was trained with the same
    # normalization — confirm against the training pipeline.
    arr = np.asarray(img_resized, dtype=np.float32) / 255.0

    if IS_TF:
        # TensorFlow: NHWC batch of one. The [p_normal, p_crack] ordering and
        # a final softmax layer are assumptions — TODO confirm for this model.
        X = arr.reshape(1, 224, 224, 3)
        # verbose=0 stops Keras printing a progress bar on every request.
        probs = model.predict(X, verbose=0)[0]
    else:
        # PyTorch: HWC -> NCHW batch of one. The CNN ends in nn.Linear (raw
        # logits), so softmax converts them to probabilities.
        import torch

        X = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            probs = torch.softmax(model(X), dim=1).numpy()[0]

    # NOTE(review): indexing [1] requires a two-unit output; a single-unit
    # sigmoid head would raise IndexError here.
    p_normal = float(probs[0])
    p_crack = float(probs[1])
    if p_crack > p_normal:
        label, conf = "crack", p_crack
    else:
        label, conf = "normal", p_normal

    # ------------------------------
    # JSON structure required by the frontend
    # ------------------------------
    return {"data": [{"label": label, "confidence": conf}]}
# ------------------------------
# 5) Gradio API interface
# ------------------------------
# One PIL image in, the JSON payload built by `predict` out; flagging is off.
# NOTE(review): `flagging_mode` is not accepted by older Gradio releases
# (which used `allow_flagging`) — confirm the pinned Gradio version.
# NOTE(review): the description below looks mis-encoded (Korean saved/read
# with the wrong code page; intended meaning: "upload a photo to get the
# crack/normal result and its probability (%)") — confirm the file is UTF-8.
# Also, `predict` returns a 0-1 fraction, not a percentage.
demo = gr.Interface(
    title="Crack Detection Classifier",
    description="์ฌ์ง์ ์ ๋ก๋ํ๋ฉด ๊ท ์ด/์ ์ ์ฌ๋ถ์ ํ๋ฅ (%)์ ๋ฐํํฉ๋๋ค.",
    flagging_mode="never",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(label="Detection Result"),
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch()