PhonePixelGhost commited on
Commit
76419e9
·
verified ·
1 Parent(s): 537d300

Upload main.py

Browse files
Files changed (1) hide show
  1. app/main.py +89 -15
app/main.py CHANGED
@@ -1,17 +1,100 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
2
  from concurrent.futures import ProcessPoolExecutor
3
  import asyncio
4
  from app.model import run_inference
5
  from app.schemas import PredictionResponse
6
 
7
  app = FastAPI(title="ResNet-18 Image Classifier", version="1.0.0")
 
8
 
9
- # ใช้ 6 Workers ตามจำนวน Physical Cores ของ Ryzen 7500F
10
- executor = ProcessPoolExecutor(max_workers=6)
11
-
12
- MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
13
  ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/webp", "image/gif"}
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @app.get("/health")
16
  async def health():
17
  return {"status": "ok"}
@@ -20,16 +103,7 @@ async def health():
20
  async def predict(file: UploadFile = File(...)):
21
  if file.content_type not in ALLOWED_CONTENT_TYPES:
22
  raise HTTPException(status_code=415, detail="Unsupported media type")
23
-
24
  image_bytes = await file.read()
25
-
26
- if len(image_bytes) > MAX_FILE_SIZE:
27
- raise HTTPException(status_code=413, detail="File too large")
28
-
29
- # รัน Inference ใน ProcessPoolExecutor เพื่อกระจายโหลดลง 6 Cores
30
  loop = asyncio.get_event_loop()
31
- try:
32
- result = await loop.run_in_executor(executor, run_inference, image_bytes)
33
- return result
34
- except Exception as e:
35
- raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import HTMLResponse
3
  from concurrent.futures import ProcessPoolExecutor
4
  import asyncio
5
  from app.model import run_inference
6
  from app.schemas import PredictionResponse
7
 
8
  app = FastAPI(title="ResNet-18 Image Classifier", version="1.0.0")
9
+ executor = ProcessPoolExecutor(max_workers=4)
10
 
 
 
 
 
11
  ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/webp", "image/gif"}
12
 
13
+ @app.get("/", response_class=HTMLResponse)
14
+ async def demo_ui():
15
+ return """
16
+ <!DOCTYPE html>
17
+ <html>
18
+ <head>
19
+ <title>ResNet-18 Image Classifier</title>
20
+ <meta name="viewport" content="width=device-width, initial-scale=1">
21
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet">
22
+ <style>
23
+ body { font-family: 'Inter', sans-serif; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; display: flex; align-items: center; justify-content: center; margin: 0; color: #333; }
24
+ .card { background: white; padding: 2rem; border-radius: 1.5rem; box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1); width: 100%; max-width: 450px; text-align: center; }
25
+ h1 { color: #1a202c; margin-bottom: 1.5rem; font-size: 1.5rem; }
26
+ #preview { width: 100%; height: 250px; border: 2px dashed #e2e8f0; border-radius: 1rem; object-fit: cover; margin-bottom: 1.5rem; display: none; }
27
+ .upload-btn-wrapper { position: relative; overflow: hidden; display: inline-block; width: 100%; }
28
+ .btn { border: none; color: white; background: #667eea; padding: 0.75rem 2rem; border-radius: 0.75rem; font-weight: 600; cursor: pointer; transition: 0.3s; width: 100%; font-size: 1rem; }
29
+ .btn:hover { background: #5a67d8; transform: translateY(-2px); }
30
+ input[type=file] { font-size: 100px; position: absolute; left: 0; top: 0; opacity: 0; cursor: pointer; }
31
+ #result { margin-top: 1.5rem; padding: 1rem; border-radius: 1rem; background: #f7fafc; display: none; text-align: left; }
32
+ .label { font-weight: 600; color: #4a5568; }
33
+ .value { color: #2d3748; float: right; }
34
+ .loading { display: none; margin-top: 1rem; }
35
+ </style>
36
+ </head>
37
+ <body>
38
+ <div class="card">
39
+ <h1>ResNet-18 Quantized</h1>
40
+ <img id="preview" src="#" alt="Preview">
41
+ <div class="upload-btn-wrapper">
42
+ <button class="btn" id="btn-text">Select Image to Predict</button>
43
+ <input type="file" id="fileInput" accept="image/*">
44
+ </div>
45
+ <div id="loading" class="loading">⌛ Predicting...</div>
46
+ <div id="result">
47
+ <div><span class="label">Label:</span> <span class="value" id="res-label">-</span></div>
48
+ <div style="margin-top:0.5rem"><span class="label">Confidence:</span> <span class="value" id="res-score">-</span></div>
49
+ <div style="margin-top:0.5rem"><span class="label">Inference:</span> <span class="value" id="res-time">-</span></div>
50
+ </div>
51
+ </div>
52
+
53
+ <script>
54
+ const fileInput = document.getElementById('fileInput');
55
+ const preview = document.getElementById('preview');
56
+ const resultDiv = document.getElementById('result');
57
+ const loading = document.getElementById('loading');
58
+ const btnText = document.getElementById('btn-text');
59
+
60
+ fileInput.onchange = evt => {
61
+ const [file] = fileInput.files;
62
+ if (file) {
63
+ preview.src = URL.createObjectURL(file);
64
+ preview.style.display = 'block';
65
+ predict(file);
66
+ }
67
+ }
68
+
69
+ async function predict(file) {
70
+ const formData = new FormData();
71
+ formData.append('file', file);
72
+
73
+ loading.style.display = 'block';
74
+ resultDiv.style.display = 'none';
75
+ btnText.disabled = true;
76
+
77
+ try {
78
+ const response = await fetch('/predict', { method: 'POST', body: formData });
79
+ const data = await response.json();
80
+
81
+ document.getElementById('res-label').innerText = data.label;
82
+ document.getElementById('res-score').innerText = (data.score * 100).toFixed(2) + '%';
83
+ document.getElementById('res-time').innerText = data.inference_time_ms.toFixed(2) + ' ms';
84
+
85
+ resultDiv.style.display = 'block';
86
+ } catch (e) {
87
+ alert('Error predicting image');
88
+ } finally {
89
+ loading.style.display = 'none';
90
+ btnText.disabled = false;
91
+ }
92
+ }
93
+ </script>
94
+ </body>
95
+ </html>
96
+ """
97
+
98
  @app.get("/health")
99
  async def health():
100
  return {"status": "ok"}
 
103
  async def predict(file: UploadFile = File(...)):
104
  if file.content_type not in ALLOWED_CONTENT_TYPES:
105
  raise HTTPException(status_code=415, detail="Unsupported media type")
 
106
  image_bytes = await file.read()
 
 
 
 
 
107
  loop = asyncio.get_event_loop()
108
+ result = await loop.run_in_executor(executor, run_inference, image_bytes)
109
+ return result