from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
from concurrent.futures import ProcessPoolExecutor
import asyncio
from app.model import run_inference
from app.schemas import PredictionResponse
from PIL import UnidentifiedImageError
# Largest accepted upload for /predict, in bytes (10 MiB).
MAX_FILE_SIZE = 10 * 1024 ** 2

# MIME types accepted by /predict; compared against UploadFile.content_type.
ALLOWED_CONTENT_TYPES = {
    "image/jpeg",
    "image/png",
    "image/webp",
    "image/gif",
}

app = FastAPI(title="ResNet-18 Image Classifier", version="1.0.0")

# Worker processes used by /predict (via run_in_executor) so inference
# does not run on the event loop thread.
executor = ProcessPoolExecutor(max_workers=4)
@app.get("/", response_class=HTMLResponse)
async def demo_ui():
    """Serve the static demo page at ``/`` as an HTML response.

    NOTE(review): the returned markup appears abbreviated/stripped in this
    revision (only text labels remain, no tags or script) — confirm it matches
    the intended UI before release.
    """
    return """
ResNet-18 Image Classifier
ResNet-18 Quantized
⌛ Predicting...
Label: -
Confidence: -
Inference: -
"""
@app.get("/health")
async def health():
    """Liveness endpoint: always report that the service is up.

    Returns:
        A constant ``{"status": "ok"}`` mapping.
    """
    payload = {"status": "ok"}
    return payload
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image using ``run_inference``.

    The upload is validated (content type first, then size) before the
    inference call is dispatched to the module-level process pool, so the
    event loop is not blocked while it runs.

    Args:
        file: The multipart-uploaded image.

    Returns:
        Whatever ``run_inference`` returns; FastAPI validates and serialises
        it against ``PredictionResponse``.

    Raises:
        HTTPException: 415 if the declared content type is not in
            ``ALLOWED_CONTENT_TYPES``; 413 if the payload is larger than
            ``MAX_FILE_SIZE`` bytes; 400 if PIL cannot identify the image;
            500 for any other inference failure.
    """
    # 1. Reject unsupported media types before touching the body.
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail="Unsupported media type")

    # 2. Read at most one byte past the limit: enough to detect an oversized
    #    upload without copying an arbitrarily large file into memory.
    image_bytes = await file.read(MAX_FILE_SIZE + 1)

    # 3. Enforce the size limit (covers test_predict_rejects_oversized_file).
    if len(image_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large")

    # 4. Run inference in the process pool and map failures to HTTP errors
    #    (covers test_predict_rejects_corrupted_file).
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(executor, run_inference, image_bytes)
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail="Invalid image file") from e
    except Exception as e:
        # NOTE(review): echoing str(e) exposes internal error text to clients;
        # consider logging it server-side and returning a generic message.
        raise HTTPException(
            status_code=500, detail=f"Inference error: {str(e)}"
        ) from e
    return result