File size: 6,386 Bytes
import asyncio
import logging
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from PIL import UnidentifiedImageError

from app.model import run_inference
from app.schemas import PredictionResponse
# Upload limits enforced by the /predict endpoint.
MAX_FILE_SIZE = 10 * 1024 ** 2  # 10 MiB (10,485,760 bytes)
ALLOWED_CONTENT_TYPES = {
    "image/jpeg",
    "image/png",
    "image/webp",
    "image/gif",
}

app = FastAPI(
    title="ResNet-18 Image Classifier",
    version="1.0.0",
)

# Worker processes used by /predict to run inference off the event loop.
# Created once, at import time, and shared by all requests.
executor = ProcessPoolExecutor(max_workers=4)
@app.get("/", response_class=HTMLResponse)
async def demo_ui():
    """Serve a single-page demo UI for the ``/predict`` endpoint.

    The page lets the user pick an image and shows a local preview. It POSTs
    the file as multipart form data (field name ``file``) to ``/predict`` and
    renders the ``label``, ``score`` and ``inference_time_ms`` fields of the
    JSON response. Error responses are surfaced via ``alert`` using the
    ``detail`` field when it is a plain string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <title>ResNet-18 Image Classifier</title>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet">
    <style>
        body { font-family: 'Inter', sans-serif; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; display: flex; align-items: center; justify-content: center; margin: 0; color: #333; }
        .card { background: white; padding: 2rem; border-radius: 1.5rem; box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1); width: 100%; max-width: 450px; text-align: center; }
        h1 { color: #1a202c; margin-bottom: 1.5rem; font-size: 1.5rem; }
        #preview { width: 100%; height: 250px; border: 2px dashed #e2e8f0; border-radius: 1rem; object-fit: cover; margin-bottom: 1.5rem; display: none; }
        .upload-btn-wrapper { position: relative; overflow: hidden; display: inline-block; width: 100%; }
        .btn { border: none; color: white; background: #667eea; padding: 0.75rem 2rem; border-radius: 0.75rem; font-weight: 600; cursor: pointer; transition: 0.3s; width: 100%; font-size: 1rem; }
        .btn:hover { background: #5a67d8; transform: translateY(-2px); }
        input[type=file] { font-size: 100px; position: absolute; left: 0; top: 0; opacity: 0; cursor: pointer; }
        #result { margin-top: 1.5rem; padding: 1rem; border-radius: 1rem; background: #f7fafc; display: none; text-align: left; }
        .label { font-weight: 600; color: #4a5568; }
        .value { color: #2d3748; float: right; }
        .loading { display: none; margin-top: 1rem; }
    </style>
</head>
<body>
    <div class="card">
        <h1>ResNet-18 Quantized</h1>
        <img id="preview" src="#" alt="Preview">
        <div class="upload-btn-wrapper">
            <button class="btn" id="btn-text">Select Image to Predict</button>
            <input type="file" id="fileInput" accept="image/*">
        </div>
        <div id="loading" class="loading">⌛ Predicting...</div>
        <div id="result">
            <div><span class="label">Label:</span> <span class="value" id="res-label">-</span></div>
            <div style="margin-top:0.5rem"><span class="label">Confidence:</span> <span class="value" id="res-score">-</span></div>
            <div style="margin-top:0.5rem"><span class="label">Inference:</span> <span class="value" id="res-time">-</span></div>
        </div>
    </div>
    <script>
        const fileInput = document.getElementById('fileInput');
        const preview = document.getElementById('preview');
        const resultDiv = document.getElementById('result');
        const loading = document.getElementById('loading');
        const btnText = document.getElementById('btn-text');
        let previewUrl = null;

        fileInput.onchange = () => {
            const [file] = fileInput.files;
            if (!file) {
                return;
            }
            // Release the previous blob URL so repeated selections do not leak memory.
            if (previewUrl) {
                URL.revokeObjectURL(previewUrl);
            }
            previewUrl = URL.createObjectURL(file);
            preview.src = previewUrl;
            preview.style.display = 'block';
            predict(file);
        };

        async function predict(file) {
            const formData = new FormData();
            formData.append('file', file);
            loading.style.display = 'block';
            resultDiv.style.display = 'none';
            // The transparent file input sits on top of the button and receives
            // the clicks, so it must be disabled to block concurrent requests.
            fileInput.disabled = true;
            btnText.disabled = true;
            try {
                const response = await fetch('/predict', { method: 'POST', body: formData });
                // Error bodies are not guaranteed to be JSON (e.g. proxy error pages).
                const data = await response.json().catch(() => ({}));
                if (!response.ok) {
                    // FastAPI validation errors (422) carry a list, not a string, in `detail`.
                    alert(typeof data.detail === 'string' ? data.detail : 'Prediction failed');
                    return;
                }
                document.getElementById('res-label').innerText = data.label;
                document.getElementById('res-score').innerText = (data.score * 100).toFixed(2) + '%';
                document.getElementById('res-time').innerText = data.inference_time_ms.toFixed(2) + ' ms';
                resultDiv.style.display = 'block';
            } catch (e) {
                alert('Error predicting image');
            } finally {
                loading.style.display = 'none';
                fileInput.disabled = false;
                btnText.disabled = false;
                // Clear the selection so choosing the same file again still fires `change`.
                fileInput.value = '';
            }
        }
    </script>
</body>
</html>
"""
@app.get("/health")
async def health():
    """Liveness probe that unconditionally reports the service as up."""
    payload = dict(status="ok")
    return payload
logger = logging.getLogger(__name__)


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image using ``run_inference``.

    The upload is validated before any inference work is scheduled:

    * its declared ``Content-Type`` must be in ``ALLOWED_CONTENT_TYPES``;
    * its size must not exceed ``MAX_FILE_SIZE`` bytes.

    Inference runs in the module-level process pool ``executor`` so the
    blocking work does not stall the event loop.

    Args:
        file: Multipart upload sent in the ``file`` form field.

    Returns:
        The value produced by ``run_inference``, which FastAPI validates and
        serializes against ``PredictionResponse``.

    Raises:
        HTTPException: 415 for a disallowed content type, 413 for an
            oversized upload, 400 when ``run_inference`` raises
            ``UnidentifiedImageError``, and 500 for any other failure.
    """
    # The Content-Type header is client-supplied, so this is only a cheap
    # first filter; undecodable bytes are rejected by the 400 branch below.
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail="Unsupported media type")

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without loading an arbitrarily large body into memory.
    image_bytes = await file.read(MAX_FILE_SIZE + 1)
    if len(image_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large")

    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(executor, run_inference, image_bytes)
    except UnidentifiedImageError as exc:
        # Corrupted or non-image payload.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    except Exception as exc:
        # Keep the full traceback in server logs. Do not echo internal error
        # text (paths, model internals) back to the client.
        logger.exception("Inference failed")
        raise HTTPException(status_code=500, detail="Inference error") from exc
    return result
|