thonguyenp commited on
Commit
360b489
·
verified ·
1 Parent(s): f6434f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -56
app.py CHANGED
@@ -1,16 +1,13 @@
1
- from fastapi import FastAPI, Request, File, UploadFile
2
- from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.templating import Jinja2Templates
6
- import gradio as gr
7
  import tensorflow as tf
8
  import numpy as np
9
  from PIL import Image
 
 
 
10
  import uvicorn
11
- import os
12
- import shutil
13
- import threading
14
 
15
  # Load mô hình TFLite
16
  interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite")
@@ -19,9 +16,9 @@ interpreter.allocate_tensors()
19
  input_details = interpreter.get_input_details()
20
  output_details = interpreter.get_output_details()
21
 
22
- class_names = ['wc_clean', 'wc_moderately_dirty', 'wc_slightly_dirty', 'wc_very_dirty']
23
 
24
- # Khởi tạo FastAPI
25
  app = FastAPI()
26
 
27
  # Cho phép CORS
@@ -37,56 +34,30 @@ app.add_middleware(
37
  templates = Jinja2Templates(directory="templates")
38
  app.mount("/static", StaticFiles(directory="static"), name="static")
39
 
40
- # Tạo thư mục lưu ảnh nếu chưa có
41
- UPLOAD_FOLDER = "uploads"
42
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
43
-
44
  # Route trang chính
45
  @app.get("/", response_class=HTMLResponse)
46
  async def read_root(request: Request):
47
  return templates.TemplateResponse("index.html", {"request": request})
48
 
49
- # API tải ảnh lên Hugging Face
50
- @app.post("/upload/")
51
- async def upload_file(file: UploadFile = File(...)):
52
- file_location = f"{UPLOAD_FOLDER}/{file.filename}"
53
- with open(file_location, "wb") as buffer:
54
- shutil.copyfileobj(file.file, buffer)
55
- return {"message": "Tải ảnh lên thành công!", "file_name": file.filename}
56
-
57
- # Hàm xử lý ảnh & dự đoán
58
- def predict(image):
59
- image = image.resize((224, 224)) # Resize ảnh về kích thước chuẩn của mô hình
60
- image = np.array(image, dtype=np.float32) / 255.0 # Chuẩn hóa ảnh về [0,1]
61
- image = np.expand_dims(image, axis=0) # Thêm batch dimension
62
-
63
- interpreter.set_tensor(input_details[0]['index'], image)
64
- interpreter.invoke()
65
- output_data = interpreter.get_tensor(output_details[0]['index'])
66
-
67
- predicted_class = np.argmax(output_data)
68
- return {class_names[i]: float(output_data[0][i]) for i in range(len(class_names))}
69
-
70
- # Tạo giao diện Gradio
71
- interface = gr.Interface(
72
- fn=predict,
73
- inputs=gr.Image(type="pil"),
74
- outputs=gr.Label(),
75
- title="WCPJ Floor Classification",
76
- description="Tải ảnh sàn nhà vệ sinh lên để phân loại"
77
- )
78
-
79
- # Chạy Gradio trên một thread riêng
80
- def run_gradio():
81
- interface.launch(server_name="0.0.0.0", server_port=7861, share=False)
82
-
83
- thread = threading.Thread(target=run_gradio)
84
- thread.start()
85
-
86
- # Route nhúng Gradio vào iframe
87
- @app.get("/gradio", response_class=HTMLResponse)
88
- async def gradio_page(request: Request):
89
- return templates.TemplateResponse("gradio.html", {"request": request})
90
 
91
  # Chạy FastAPI
92
  if __name__ == "__main__":
 
1
+ from fastapi import FastAPI, Request, File, UploadFile, HTTPException
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from starlette.staticfiles import StaticFiles
 
 
 
4
  import tensorflow as tf
5
  import numpy as np
6
  from PIL import Image
7
+ import io
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.templating import Jinja2Templates
10
  import uvicorn
 
 
 
11
 
12
  # Load mô hình TFLite
13
  interpreter = tf.lite.Interpreter(model_path="model_wcpj_pro.tflite")
 
16
  input_details = interpreter.get_input_details()
17
  output_details = interpreter.get_output_details()
18
 
19
+ class_names = ['wc_clean', 'wc_moderately_dirty', 'wc_slightly_dirty', 'wc_very_dirty'] # Thay bằng nhãn thực tế của bạn
20
 
21
+ # Khởi tạo ứng dụng FastAPI
22
  app = FastAPI()
23
 
24
  # Cho phép CORS
 
34
  templates = Jinja2Templates(directory="templates")
35
  app.mount("/static", StaticFiles(directory="static"), name="static")
36
 
 
 
 
 
37
  # Route trang chính
38
  @app.get("/", response_class=HTMLResponse)
39
  async def read_root(request: Request):
40
  return templates.TemplateResponse("index.html", {"request": request})
41
 
42
+ # API tải ảnh lên phân loại
43
+ @app.post("/predict/")
44
+ async def predict(file: UploadFile = File(...)):
45
+ try:
46
+ image = Image.open(io.BytesIO(await file.read()))
47
+ image = image.resize((224, 224)) # Resize ảnh về đúng kích thước mô hình yêu cầu
48
+ image = np.array(image, dtype=np.float32) / 255.0 # Chuẩn hóa về [0,1]
49
+ image = np.expand_dims(image, axis=0) # Thêm batch dimension
50
+
51
+ interpreter.set_tensor(input_details[0]['index'], image)
52
+ interpreter.invoke()
53
+ output_data = interpreter.get_tensor(output_details[0]['index'])
54
+
55
+ predicted_class = np.argmax(output_data)
56
+ result = {class_names[i]: float(output_data[0][i]) for i in range(len(class_names))}
57
+ return JSONResponse(content={"prediction": result, "class": class_names[predicted_class]})
58
+
59
+ except Exception as e:
60
+ raise HTTPException(status_code=400, detail=f"Lỗi xử lý ảnh: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Chạy FastAPI
63
  if __name__ == "__main__":