thonguyenp's picture
Update app.py
360b489 verified
from fastapi import FastAPI, Request, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from starlette.staticfiles import StaticFiles
import tensorflow as tf
import numpy as np
from PIL import Image
import io
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
import uvicorn
# Load mô hình TFLite
interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
class_names = ['wc_clean', 'wc_moderately_dirty', 'wc_slightly_dirty', 'wc_very_dirty'] # Thay bằng nhãn thực tế của bạn
# Khởi tạo ứng dụng FastAPI
app = FastAPI()
# Cho phép CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Cấu hình Jinja2 để render HTML
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Route trang chính
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
# API tải ảnh lên và phân loại
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
try:
image = Image.open(io.BytesIO(await file.read()))
image = image.resize((224, 224)) # Resize ảnh về đúng kích thước mô hình yêu cầu
image = np.array(image, dtype=np.float32) / 255.0 # Chuẩn hóa về [0,1]
image = np.expand_dims(image, axis=0) # Thêm batch dimension
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
predicted_class = np.argmax(output_data)
result = {class_names[i]: float(output_data[0][i]) for i in range(len(class_names))}
return JSONResponse(content={"prediction": result, "class": class_names[predicted_class]})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Lỗi xử lý ảnh: {str(e)}")
# Chạy FastAPI
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)