# model_wcpj / app.py
# Author: thonguyenp — "Upload app.py" (commit d7117e2, verified)
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from starlette.staticfiles import StaticFiles
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
# Load the TFLite model once, at import time, so every prediction reuses the
# same interpreter instead of re-reading the model file.
interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite")
interpreter.allocate_tensors()
# Tensor metadata (index, shape, dtype, ...) used to feed the input tensor and
# read the output tensor.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
# Human-readable labels, in the same order as the model's output scores
# (score i is reported under class_names[i]). Replace with your own labels if
# the model is retrained.
class_names = [
    'wc_clean',
    'wc_moderately_dirty',
    'wc_slightly_dirty',
    'wc_very_dirty',
]
# Create the FastAPI application.
app = FastAPI()
# Allow cross-origin requests from any origin (only needed if the API is
# called from pages served by another host).
# allow_credentials is False on purpose: a wildcard origin must not be combined
# with credentials. Starlette "works around" that combination by echoing back
# whatever Origin the browser sends, which would let any website make
# cookie-bearing requests to this app. Nothing in this file reads cookies or
# auth headers, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# HTML pages are rendered with Jinja2 from the local `templates/` directory.
templates = Jinja2Templates(directory="templates")
# Serve the contents of the local `static/` directory under the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory="static"),
    name="static",
)
# Landing page: render templates/index.html.
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    # The (name, context) form of TemplateResponse expects the incoming
    # request under the "request" key of the template context.
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
# Image preprocessing and prediction.
def predict(image):
    """Classify a toilet-floor image with the TFLite model.

    Args:
        image: A PIL image. Gradio passes one because the interface input is
            ``gr.Image(type="pil")``.

    Returns:
        dict mapping each label in ``class_names`` to the model's score for
        that label as a Python float. This is the format ``gr.Label`` displays.

    Raises:
        ValueError: if the model returns a different number of scores than
            there are entries in ``class_names``.
    """
    # Force 3-channel RGB before resizing. Grayscale ("L"), palette ("P") or
    # RGBA uploads would otherwise produce a 2-D or 4-channel array, and
    # set_tensor would reject the shape.
    # NOTE(review): assumes the model expects (1, 224, 224, 3) float32 input;
    # confirm against input_details[0]['shape'] / ['dtype'].
    image = image.convert("RGB").resize((224, 224))
    # Scale pixels from [0, 255] to [0, 1] and add the leading batch dimension.
    batch = np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, axis=0)

    interpreter.set_tensor(input_details[0]['index'], batch)
    interpreter.invoke()
    # The output has a batch dimension; take the scores for the single image.
    scores = interpreter.get_tensor(output_details[0]['index'])[0]

    # A length mismatch means the label list does not match the model. Fail
    # loudly rather than silently mislabel or drop scores.
    if len(scores) != len(class_names):
        raise ValueError(
            f"model returned {len(scores)} scores but {len(class_names)} "
            "class names are defined"
        )
    return {name: float(score) for name, score in zip(class_names, scores)}
# Gradio UI: upload one image (handed to predict as a PIL image) and show the
# per-class scores returned by predict as a label.
# The description (Vietnamese) reads: "Upload a photo of a toilet floor to
# classify it."
interface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(),
    description="Tải ảnh sàn nhà vệ sinh lên để phân loại",
    title="WCPJ Floor Classification",
)
# Serve the Gradio UI from this FastAPI app under /gradio.
# The previous handler called interface.launch(share=True, inline=False) on
# every request. launch() starts a separate Gradio server plus a public share
# tunnel, then blocks the calling thread (here the event loop), so the whole
# app hung. Instead, the Gradio app is mounted onto `app` once, at import time.
@app.get("/gradio")
async def gradio_interface():
    # Keep the existing GET /gradio route working: send the browser to the
    # mounted Gradio app, which is served under the "/gradio/" prefix.
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/gradio/")


app = gr.mount_gradio_app(app, interface, path="/gradio")