# crack-api / app.py
# Thompson001's picture
# Update app.py
# 6b8d8ce verified
# raw
# history blame
# 2.83 kB
import gradio as gr
import numpy as np
from PIL import Image
import os
# ------------------------------
# 1) Model file location
# ------------------------------
# Path of the serialized classifier; its extension decides which backend
# flag is set below.
MODEL_PATH = "crack_detection.h5"
# True when the file uses a Keras/TensorFlow extension (.h5 or .keras).
IS_TF = MODEL_PATH.endswith((".h5", ".keras"))
# ------------------------------
# 2) Load the TensorFlow model
# ------------------------------
if IS_TF:
    # Imported lazily so that only the backend actually needed is loaded.
    import tensorflow as tf

    model = tf.keras.models.load_model(MODEL_PATH)
    print("🔥 Loaded TensorFlow crack classifier")
# ------------------------------
# 3) Load the PyTorch model
# ------------------------------
else:
    import torch
    from torch import nn

    class CNN(nn.Module):
        """Two conv/ReLU/max-pool stages followed by a 2-way linear head.

        NOTE: this architecture has to be adjusted to match the structure of
        the model whose weights are stored in the checkpoint; otherwise
        ``load_state_dict`` below will fail on mismatched keys or shapes.
        """

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Flatten(),
                # 32 channels * 56 * 56: two 2x poolings of a 224x224 input.
                # Needs adjusting if the input image size changes.
                nn.Linear(32 * 56 * 56, 2),
            )

        def forward(self, x):
            """Return the raw two-class scores (no softmax) for batch ``x``."""
            return self.net(x)

    model = CNN()
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. On PyTorch versions that support it, passing
    # weights_only=True restricts unpickling to tensors/primitive containers.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    # Switch to inference mode (affects layers such as dropout/batch-norm).
    model.eval()
    print("🔥 Loaded PyTorch crack classifier")
# ------------------------------
# 4) End-to-end prediction function
# ------------------------------
def predict(img: Image.Image):
    """Classify an image as ``"crack"`` or ``"normal"``.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1] and
    passed to the module-level ``model`` (TensorFlow when ``IS_TF`` is true,
    PyTorch otherwise).

    Args:
        img: PIL image supplied by the Gradio ``Image`` input. May be ``None``
            when the request carries no image.

    Returns:
        ``{"data": [{"label": <"crack" | "normal">, "confidence": <float>}]}``
        where ``confidence`` is the score of the chosen label.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise gr.Error("No image provided. Please upload an image.")

    # Input conversion. Force three channels first: uploads may be RGBA,
    # grayscale or palette images, which would break the (224, 224, 3)
    # reshape / channel permute below.
    img_resized = img.convert("RGB").resize((224, 224))
    arr = np.array(img_resized) / 255.0

    if IS_TF:
        # TensorFlow: channels-last batch of one.
        X = arr.reshape(1, 224, 224, 3)
        probs = model.predict(X)[0]  # assumed order: [p_normal, p_crack]
    else:
        # PyTorch: channels-first float32 batch of one.
        import torch

        X = (
            torch.tensor(arr)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = model(X)
        probs = torch.softmax(logits, dim=1).numpy()[0]

    p_normal = float(probs[0])
    p_crack = float(probs[1])

    # Ties resolve to "normal", since "crack" requires a strictly higher score.
    if p_crack > p_normal:
        label = "crack"
        conf = p_crack
    else:
        label = "normal"
        conf = p_normal

    # ------------------------------
    # JSON structure expected by the front end
    # ------------------------------
    return {
        "data": [
            {
                "label": label,
                "confidence": conf
            }
        ]
    }
# ------------------------------
# 5) Gradio API Interface
# ------------------------------
# Wires predict() to a PIL image input and a JSON output. The description
# string (Korean) tells the user that uploading a photo returns the
# crack/normal verdict together with its probability.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(label="Detection Result"),
    title="Crack Detection Classifier",
    description="사진을 업로드하면 균열/정상 여부와 확률(%)을 반환합니다.",
    flagging_mode="never",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()