Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, File, UploadFile, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from starlette.staticfiles import StaticFiles | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.templating import Jinja2Templates | |
| import uvicorn | |
| # Load mô hình TFLite | |
| interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite") | |
| interpreter.allocate_tensors() | |
| input_details = interpreter.get_input_details() | |
| output_details = interpreter.get_output_details() | |
| class_names = ['wc_clean', 'wc_moderately_dirty', 'wc_slightly_dirty', 'wc_very_dirty'] # Thay bằng nhãn thực tế của bạn | |
| # Khởi tạo ứng dụng FastAPI | |
| app = FastAPI() | |
| # Cho phép CORS | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Cấu hình Jinja2 để render HTML | |
| templates = Jinja2Templates(directory="templates") | |
| app.mount("/static", StaticFiles(directory="static"), name="static") | |
| # Route trang chính | |
| async def read_root(request: Request): | |
| return templates.TemplateResponse("index.html", {"request": request}) | |
| # API tải ảnh lên và phân loại | |
| async def predict(file: UploadFile = File(...)): | |
| try: | |
| image = Image.open(io.BytesIO(await file.read())) | |
| image = image.resize((224, 224)) # Resize ảnh về đúng kích thước mô hình yêu cầu | |
| image = np.array(image, dtype=np.float32) / 255.0 # Chuẩn hóa về [0,1] | |
| image = np.expand_dims(image, axis=0) # Thêm batch dimension | |
| interpreter.set_tensor(input_details[0]['index'], image) | |
| interpreter.invoke() | |
| output_data = interpreter.get_tensor(output_details[0]['index']) | |
| predicted_class = np.argmax(output_data) | |
| result = {class_names[i]: float(output_data[0][i]) for i in range(len(class_names))} | |
| return JSONResponse(content={"prediction": result, "class": class_names[predicted_class]}) | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=f"Lỗi xử lý ảnh: {str(e)}") | |
| # Chạy FastAPI | |
| if __name__ == "__main__": | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |