pfox1995 commited on
Commit
efb4788
Β·
verified Β·
1 Parent(s): 722e573

Add server.py (server.py + restart_server.sh + Korean README expansion)

Browse files
Files changed (1) hide show
  1. server.py +264 -0
server.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """FastAPI server for the Korean pest detector.
3
+
4
+ Wraps the validated Unsloth FastVisionModel + PEFT runtime LoRA setup
5
+ (load_in_4bit=True by default β†’ ~8.7 GB VRAM).
6
+
7
+ Endpoints:
8
+ GET /health β†’ {"status": "ok", "model_loaded": bool}
9
+ GET /classes β†’ ["κ²€κ±°μ„Έλ―Έλ°€λ‚˜λ°©", ...] (19 classes)
10
+ GET / β†’ minimal HTML upload form
11
+ POST /classify β†’ multipart file OR JSON {"image": "<base64>"}
12
+ returns {"pred": ..., "raw": ..., "elapsed_s": ..., "all_classes": [...]}
13
+
14
+ Env:
15
+ BASE_MODEL default: unsloth/Qwen3.5-9B
16
+ ADAPTER default: pfox1995/pest-detector-deploy
17
+ LOAD_IN_4BIT "true"/"false" (default: true)
18
+ PORT default: 8080
19
+
20
+ Usage:
21
+ python server.py
22
+ """
23
+ import base64
24
+ import io
25
+ import os
26
+ import time
27
+ from contextlib import asynccontextmanager
28
+ from typing import Optional
29
+
30
+ import torch
31
+ import uvicorn
32
+ from fastapi import FastAPI, File, UploadFile, HTTPException
33
+ from fastapi.responses import HTMLResponse, JSONResponse
34
+ from PIL import Image
35
+ from pydantic import BaseModel
36
+
37
+ # ─── Constants from training (DO NOT change) ─────────────────────────────
38
+ PEST_CLASSES = [
39
+ "κ²€κ±°μ„Έλ―Έλ°€λ‚˜λ°©", "κ½ƒλ…Έλž‘μ΄μ±„λ²Œλ ˆ", "담배가루이", "λ‹΄λ°°κ±°μ„Έλ―Έλ‚˜λ°©",
40
+ "λ‹΄λ°°λ‚˜λ°©", "λ„λ‘‘λ‚˜λ°©", "λ¨Ήλ…Έλ¦°μž¬", "λͺ©ν™”λ°”λ‘‘λͺ…λ‚˜λ°©", "무잎벌",
41
+ "λ°°μΆ”μ’€λ‚˜λ°©", "λ°°μΆ”ν°λ‚˜λΉ„", "벼룩잎벌레", "λΉ„λ‹¨λ…Έλ¦°μž¬", "μ©λ©λ‚˜λ¬΄λ…Έλ¦°μž¬",
42
+ "μ•Œλ½μˆ˜μ—Όλ…Έλ¦°μž¬", "정상", "큰28μ λ°•μ΄λ¬΄λ‹Ήλ²Œλ ˆ", "ν†±λ‹€λ¦¬κ°œλ―Έν—ˆλ¦¬λ…Έλ¦°μž¬",
43
+ "νŒŒλ°€λ‚˜λ°©",
44
+ ]
45
+ SYSTEM_MSG = (
46
+ "당신은 μž‘λ¬Ό ν•΄μΆ© 식별 μ „λ¬Έκ°€μž…λ‹ˆλ‹€. "
47
+ "사진을 보고 ν•΄μΆ©μ˜ μ΄λ¦„λ§Œ ν•œκ΅­μ–΄λ‘œ λ‹΅ν•˜μ„Έμš”. "
48
+ '해좩이 μ—†μœΌλ©΄ "정상"이라고만 λ‹΅ν•˜μ„Έμš”. '
49
+ "λΆ€κ°€ μ„€λͺ… 없이 μ΄λ¦„λ§Œ 좜λ ₯ν•˜μ„Έμš”."
50
+ )
51
+ USER_PROMPT = "이 사진에 μžˆλŠ” ν•΄μΆ©μ˜ 이름을 μ•Œλ €μ£Όμ„Έμš”."
52
+ LETTERBOX_SIZE = 512
53
+ LETTERBOX_FILL = (128, 128, 128)
54
+
55
+
56
+ def letterbox(img: Image.Image, size: int = LETTERBOX_SIZE) -> Image.Image:
57
+ img = img.convert("RGB")
58
+ w, h = img.size
59
+ scale = size / max(w, h)
60
+ nw, nh = int(round(w * scale)), int(round(h * scale))
61
+ resized = img.resize((nw, nh), Image.Resampling.LANCZOS)
62
+ canvas = Image.new("RGB", (size, size), LETTERBOX_FILL)
63
+ canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
64
+ return canvas
65
+
66
+
67
+ # ─── Model state ─────────────────────────────────────────────────────────
68
+ class ModelState:
69
+ model = None
70
+ tokenizer = None
71
+ text_tokenizer = None # underlying transformers tokenizer (for stop_strings=)
72
+
73
+
74
+ STATE = ModelState()
75
+
76
+
77
+ def load_model():
78
+ from unsloth import FastVisionModel
79
+ from peft import PeftModel
80
+ from huggingface_hub import snapshot_download
81
+
82
+ base = os.environ.get("BASE_MODEL", "unsloth/Qwen3.5-9B")
83
+ adapter = os.environ.get("ADAPTER", "pfox1995/pest-detector-deploy")
84
+ four_bit = os.environ.get("LOAD_IN_4BIT", "true").lower() == "true"
85
+
86
+ if os.environ.get("HF_TOKEN"):
87
+ from huggingface_hub import login
88
+ login(token=os.environ["HF_TOKEN"], add_to_git_credential=False)
89
+
90
+ print(f"[startup] FastVisionModel.from_pretrained({base}, load_in_4bit={four_bit})", flush=True)
91
+ t0 = time.time()
92
+ model, tok = FastVisionModel.from_pretrained(base, load_in_4bit=four_bit)
93
+ print(f"[startup] loaded base in {time.time()-t0:.1f}s; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
94
+
95
+ adapter_dir = adapter if os.path.isdir(adapter) else snapshot_download(repo_id=adapter)
96
+ print(f"[startup] attaching LoRA: {adapter_dir}", flush=True)
97
+ model = PeftModel.from_pretrained(model, adapter_dir)
98
+ FastVisionModel.for_inference(model)
99
+ model.eval()
100
+ print(f"[startup] ready; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
101
+
102
+ STATE.model = model
103
+ STATE.tokenizer = tok
104
+ STATE.text_tokenizer = tok.tokenizer if hasattr(tok, "tokenizer") else tok
105
+
106
+
107
+ def classify_image(img: Image.Image) -> dict:
108
+ if STATE.model is None:
109
+ raise RuntimeError("Model not loaded")
110
+ image = letterbox(img)
111
+ messages = [
112
+ {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]},
113
+ {"role": "user", "content": [
114
+ {"type": "image", "image": image},
115
+ {"type": "text", "text": USER_PROMPT},
116
+ ]},
117
+ ]
118
+ text = STATE.tokenizer.apply_chat_template(
119
+ messages, add_generation_prompt=True, enable_thinking=False,
120
+ )
121
+ inputs = STATE.tokenizer(
122
+ image, text, add_special_tokens=False, return_tensors="pt",
123
+ ).to("cuda")
124
+
125
+ t0 = time.time()
126
+ with torch.inference_mode():
127
+ out = STATE.model.generate(
128
+ **inputs,
129
+ max_new_tokens=10,
130
+ use_cache=True,
131
+ stop_strings=["\n"],
132
+ tokenizer=STATE.text_tokenizer,
133
+ )
134
+ elapsed = time.time() - t0
135
+ raw = STATE.tokenizer.decode(
136
+ out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True,
137
+ ).strip()
138
+ pred = raw if raw in PEST_CLASSES else None
139
+ if pred is None:
140
+ for c in sorted(PEST_CLASSES, key=len, reverse=True):
141
+ if raw.startswith(c):
142
+ pred = c
143
+ break
144
+ if pred is None:
145
+ pred = raw # surface raw text if no class match (debugging signal)
146
+ return {"pred": pred, "raw": raw, "elapsed_s": round(elapsed, 3)}
147
+
148
+
149
+ # ─── FastAPI app ─────────────────────────────────────────────────────────
150
+ @asynccontextmanager
151
+ async def lifespan(app: FastAPI):
152
+ load_model()
153
+ yield
154
+ # nothing to clean up
155
+
156
+
157
+ app = FastAPI(
158
+ title="Korean Pest Detector",
159
+ description="Qwen3.5-9B + LoRA via Unsloth + PEFT runtime",
160
+ lifespan=lifespan,
161
+ )
162
+
163
+
164
+ @app.get("/health")
165
+ def health():
166
+ return {"status": "ok", "model_loaded": STATE.model is not None}
167
+
168
+
169
+ @app.get("/classes")
170
+ def classes():
171
+ return {"classes": PEST_CLASSES, "count": len(PEST_CLASSES)}
172
+
173
+
174
+ class ClassifyJSON(BaseModel):
175
+ image: str # base64-encoded image bytes
176
+
177
+
178
+ @app.post("/classify")
179
+ async def classify(
180
+ file: Optional[UploadFile] = File(None),
181
+ ):
182
+ """Accepts multipart 'file' upload."""
183
+ if file is None:
184
+ raise HTTPException(400, "Provide 'file' multipart field, or POST JSON to /classify_b64")
185
+ try:
186
+ img_bytes = await file.read()
187
+ img = Image.open(io.BytesIO(img_bytes))
188
+ except Exception as e:
189
+ raise HTTPException(400, f"could not parse image: {e}")
190
+ try:
191
+ return JSONResponse(classify_image(img))
192
+ except Exception as e:
193
+ raise HTTPException(500, f"inference error: {e}")
194
+
195
+
196
+ @app.post("/classify_b64")
197
+ async def classify_b64(payload: ClassifyJSON):
198
+ """Accepts JSON {"image": "<base64-encoded image>"}."""
199
+ try:
200
+ img_bytes = base64.b64decode(payload.image)
201
+ img = Image.open(io.BytesIO(img_bytes))
202
+ except Exception as e:
203
+ raise HTTPException(400, f"could not decode image: {e}")
204
+ try:
205
+ return JSONResponse(classify_image(img))
206
+ except Exception as e:
207
+ raise HTTPException(500, f"inference error: {e}")
208
+
209
+
210
+ @app.get("/", response_class=HTMLResponse)
211
+ def index():
212
+ return """
213
+ <!DOCTYPE html>
214
+ <html lang="ko">
215
+ <head>
216
+ <meta charset="utf-8">
217
+ <title>Korean Pest Detector</title>
218
+ <style>
219
+ body { font-family: -apple-system, system-ui, sans-serif; max-width: 640px; margin: 2rem auto; padding: 0 1rem; }
220
+ h1 { font-size: 1.4rem; }
221
+ .drop { border: 2px dashed #aaa; border-radius: 12px; padding: 2rem; text-align: center; cursor: pointer; }
222
+ .drop:hover { background: #f5f5f5; }
223
+ pre { background: #f5f5f5; padding: 1rem; border-radius: 8px; overflow-x: auto; }
224
+ img { max-width: 100%; border-radius: 8px; margin-top: 1rem; }
225
+ .pred { font-size: 1.6rem; font-weight: bold; color: #2a6b3a; }
226
+ .err { color: #b00; }
227
+ </style>
228
+ </head>
229
+ <body>
230
+ <h1>🌾 Korean Pest Detector</h1>
231
+ <p>Qwen3.5-9B + LoRA (Unsloth + PEFT runtime). 19개 클래슀, ν•œκ΅­μ–΄ 좜λ ₯.</p>
232
+ <input id="f" type="file" accept="image/*">
233
+ <div id="result"></div>
234
+ <script>
235
+ document.getElementById('f').onchange = async (e) => {
236
+ const file = e.target.files[0];
237
+ if (!file) return;
238
+ const r = document.getElementById('result');
239
+ r.innerHTML = '<p>뢄석 쀑...</p>';
240
+ const fd = new FormData();
241
+ fd.append('file', file);
242
+ const t0 = performance.now();
243
+ try {
244
+ const resp = await fetch('/classify', {method: 'POST', body: fd});
245
+ const j = await resp.json();
246
+ if (!resp.ok) throw new Error(j.detail || 'error');
247
+ const elapsed = ((performance.now() - t0) / 1000).toFixed(2);
248
+ const url = URL.createObjectURL(file);
249
+ r.innerHTML = `<p class="pred">${j.pred}</p>
250
+ <p>raw: <code>${j.raw}</code> Β· μΆ”λ‘  ${j.elapsed_s}s Β· 총 ${elapsed}s</p>
251
+ <img src="${url}">`;
252
+ } catch (err) {
253
+ r.innerHTML = '<p class="err">'+err.message+'</p>';
254
+ }
255
+ };
256
+ </script>
257
+ </body>
258
+ </html>
259
+ """
260
+
261
+
262
+ if __name__ == "__main__":
263
+ port = int(os.environ.get("PORT", "8080"))
264
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")