from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse from starlette.staticfiles import StaticFiles import gradio as gr import tensorflow as tf import numpy as np from PIL import Image from fastapi.middleware.cors import CORSMiddleware from fastapi.templating import Jinja2Templates # Load mô hình TFLite interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite") interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() class_names = ['wc_clean', 'wc_moderately_dirty', 'wc_slightly_dirty', 'wc_very_dirty'] # Thay bằng nhãn của bạn # Khởi tạo ứng dụng FastAPI app = FastAPI() # Cho phép CORS (nếu cần) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Cấu hình Jinja2 để render HTML templates = Jinja2Templates(directory="templates") app.mount("/static", StaticFiles(directory="static"), name="static") # Route trang chính @app.get("/", response_class=HTMLResponse) async def read_root(request: Request): return templates.TemplateResponse("index.html", {"request": request}) # Hàm xử lý ảnh & dự đoán def predict(image): image = image.resize((224, 224)) # Resize ảnh image = np.array(image, dtype=np.float32) / 255.0 # Chuẩn hóa image = np.expand_dims(image, axis=0) # Thêm batch dimension interpreter.set_tensor(input_details[0]['index'], image) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) predicted_class = np.argmax(output_data) return {class_names[i]: float(output_data[0][i]) for i in range(len(class_names))} # API Gradio interface = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(), title="WCPJ Floor Classification", description="Tải ảnh sàn nhà vệ sinh lên để phân loại" ) # Chạy Gradio trên FastAPI @app.get("/gradio") async def gradio_interface(): return interface.launch(share=True, inline=False)