| |
"""FastAPI server for the Korean pest detector.

Wraps the validated Unsloth FastVisionModel + PEFT runtime LoRA setup
(load_in_4bit=True by default → ~8.7 GB VRAM).

Endpoints:
    GET  /health        → {"status": "ok", "model_loaded": bool}
    GET  /classes       → {"classes": ["검거세미밤나방", ...], "count": 19}
    GET  /              → minimal HTML upload form
    POST /classify      → multipart 'file' upload
    POST /classify_b64  → JSON {"image": "<base64>"}
        both POST endpoints return {"pred": ..., "raw": ..., "elapsed_s": ...}

Env:
    BASE_MODEL      default: unsloth/Qwen3.5-9B
    ADAPTER         default: pfox1995/pest-detector-deploy
    LOAD_IN_4BIT    "true"/"false" (default: true)
    PORT            default: 8080
    HF_TOKEN        optional; when set, used to log in to the Hugging Face Hub

Usage:
    python server.py
"""
| import base64 |
| import io |
| import os |
| import time |
| from contextlib import asynccontextmanager |
| from typing import Optional |
|
|
| import torch |
| import uvicorn |
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.responses import HTMLResponse, JSONResponse |
| from PIL import Image |
| from pydantic import BaseModel |
|
|
| |
# NOTE(review): the Korean literals in this block were restored from
# mis-decoded (mojibake) UTF-8 in which some bytes had been lost.  The class
# labels are unambiguous in context, but a few prompt syllables had to be
# inferred (e.g. "작물" in SYSTEM_MSG, and the spacing of "설명없이").
# Confirm all of them byte-for-byte against the strings used during
# fine-tuning before deploying.

# Class labels the model is expected to answer with.  "정상" is the label the
# system prompt asks for when the photo contains no pest.
PEST_CLASSES = [
    "검거세미밤나방", "꽃노랑총채벌레", "담배가루이", "담배거세미나방",
    "담배나방", "도둑나방", "먹노린재", "목화바둘명나방", "무잎벌",
    "배추좀나방", "배추흰나비", "벼룩잎벌레", "비단노린재", "썩덩나무노린재",
    "알락수염노린재", "정상", "큰28점박이무당벌레", "톱다리개미허리노린재",
    "파밤나방",
]

# System prompt: answer with the pest name only, in Korean, and with "정상"
# when no pest is present.
SYSTEM_MSG = (
    "당신은 작물 해충 식별 전문가입니다. "
    "사진을 보고 해충의 이름만 한국어로 답하세요. "
    '해충이 없으면 "정상"이라고만 답하세요. '
    "부가 설명없이 이름만 출력하세요."
)

# User turn sent together with the image.
USER_PROMPT = "이 사진에 있는 해충의 이름을 알려주세요."

# Edge length (pixels) of the square canvas produced by letterbox().
LETTERBOX_SIZE = 512
# RGB colour used to pad the area not covered by the resized image.
LETTERBOX_FILL = (128, 128, 128)
|
|
|
|
def letterbox(img: Image.Image, size: int = LETTERBOX_SIZE) -> Image.Image:
    """Fit *img* into a ``size`` x ``size`` square, padding the remainder.

    The image is converted to RGB, scaled so that its longer side equals
    ``size`` (aspect ratio preserved), and pasted centred on a canvas filled
    with ``LETTERBOX_FILL``.

    Args:
        img: Source image; it is converted to RGB before resizing.
        size: Edge length of the square output, in pixels.

    Returns:
        A new RGB image of ``size`` x ``size`` pixels.
    """
    img = img.convert("RGB")
    w, h = img.size
    scale = size / max(w, h)
    # Clamp each side to at least 1 px: for extreme aspect ratios the short
    # side rounds down to 0, which is not a valid resize target.
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = img.resize((nw, nh), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (size, size), LETTERBOX_FILL)
    # Integer-divide the leftover space so the image sits in the centre.
    canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas
|
|
|
|
| |
class ModelState:
    """Holder for the loaded model and its tokenizers.

    All three slots start as ``None`` and are filled in by ``load_model``.
    """

    # model: the loaded model; tokenizer: the object returned with it;
    # text_tokenizer: the plain text tokenizer passed to ``generate``.
    model = tokenizer = text_tokenizer = None


# Single shared instance used by the loader and the request handlers.
STATE = ModelState()
|
|
|
|
def load_model():
    """Load the base model, attach the LoRA adapter and publish both on STATE.

    Configuration is read from the environment: ``BASE_MODEL``, ``ADAPTER``
    and ``LOAD_IN_4BIT``.  When ``HF_TOKEN`` is set it is used to log in to
    the Hugging Face Hub first.  ``ADAPTER`` may be a local directory;
    anything else is treated as a Hub repo id and downloaded.
    """
    # The model libraries are imported here rather than at module import time.
    from unsloth import FastVisionModel
    from peft import PeftModel
    from huggingface_hub import snapshot_download

    env = os.environ
    base = env.get("BASE_MODEL", "unsloth/Qwen3.5-9B")
    adapter = env.get("ADAPTER", "pfox1995/pest-detector-deploy")
    # Only the exact word "true" (any letter case) enables 4-bit loading.
    four_bit = env.get("LOAD_IN_4BIT", "true").lower() == "true"

    hf_token = env.get("HF_TOKEN")
    if hf_token:
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)

    print(f"[startup] FastVisionModel.from_pretrained({base}, load_in_4bit={four_bit})", flush=True)
    started = time.time()
    model, tok = FastVisionModel.from_pretrained(base, load_in_4bit=four_bit)
    load_s = time.time() - started
    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f"[startup] loaded base in {load_s:.1f}s; vram={vram_gb:.1f} GB", flush=True)

    if os.path.isdir(adapter):
        adapter_dir = adapter
    else:
        adapter_dir = snapshot_download(repo_id=adapter)
    print(f"[startup] attaching LoRA: {adapter_dir}", flush=True)
    model = PeftModel.from_pretrained(model, adapter_dir)
    FastVisionModel.for_inference(model)
    model.eval()
    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f"[startup] ready; vram={vram_gb:.1f} GB", flush=True)

    STATE.model = model
    STATE.tokenizer = tok
    # Use the wrapped ``.tokenizer`` attribute when the returned object has
    # one; otherwise the returned object itself.
    STATE.text_tokenizer = getattr(tok, "tokenizer", tok)
|
|
|
|
def classify_image(img: Image.Image) -> dict:
    """Run the model on one image and map its answer onto a known class.

    Args:
        img: The image to classify; it is letterboxed before inference.

    Returns:
        A dict with ``pred`` (the matched class, or the raw text when nothing
        matches), ``raw`` (the stripped model output) and ``elapsed_s``
        (generation time in seconds, rounded to 3 decimals).

    Raises:
        RuntimeError: If the model has not been loaded.
    """
    if STATE.model is None:
        raise RuntimeError("Model not loaded")

    boxed = letterbox(img)
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_MSG}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": boxed},
            {"type": "text", "text": USER_PROMPT},
        ],
    }
    prompt = STATE.tokenizer.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = STATE.tokenizer(
        boxed,
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")

    # Only the generate() call is timed.
    started = time.time()
    with torch.inference_mode():
        out = STATE.model.generate(
            **inputs,
            max_new_tokens=10,
            use_cache=True,
            stop_strings=["\n"],
            tokenizer=STATE.text_tokenizer,
        )
    elapsed = time.time() - started

    # Skip the prompt tokens so that only the generated answer is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    raw = STATE.tokenizer.decode(
        out[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()

    if raw in PEST_CLASSES:
        pred = raw
    else:
        # Longest labels first, so a label that is a prefix of another one
        # cannot win over the longer match; fall back to the raw text.
        by_length = sorted(PEST_CLASSES, key=len, reverse=True)
        pred = next((label for label in by_length if raw.startswith(label)), raw)

    return {"pred": pred, "raw": raw, "elapsed_s": round(elapsed, 3)}
|
|
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: load the model before the application starts serving.

    ``load_model`` is called synchronously, so startup does not complete
    until it returns.  Nothing is torn down after the ``yield``.
    """
    load_model()
    yield
| |
|
|
|
|
# The ASGI application; ``lifespan`` loads the model during startup.
app = FastAPI(
    lifespan=lifespan,
    title="Korean Pest Detector",
    description="Qwen3.5-9B + LoRA via Unsloth + PEFT runtime",
)
|
|
|
|
@app.get("/health")
def health():
    """Report that the server is up and whether the model has been loaded."""
    model_loaded = STATE.model is not None
    return {"status": "ok", "model_loaded": model_loaded}
|
|
|
|
@app.get("/classes")
def classes():
    """Return every known class label together with the number of labels."""
    labels = PEST_CLASSES
    return {"classes": labels, "count": len(labels)}
|
|
|
|
class ClassifyJSON(BaseModel):
    """Request body for ``POST /classify_b64``."""

    # Base64-encoded bytes of the image file.
    image: str
|
|
|
|
@app.post("/classify")
async def classify(
    file: Optional[UploadFile] = File(None),
):
    """Classify an image sent as a multipart ``file`` upload.

    Returns:
        The result of ``classify_image`` as a JSON response.

    Raises:
        HTTPException: 400 when ``file`` is missing or cannot be decoded as
            an image; 500 when inference fails.
    """
    if file is None:
        raise HTTPException(400, "Provide 'file' multipart field, or POST JSON to /classify_b64")
    try:
        img_bytes = await file.read()
        img = Image.open(io.BytesIO(img_bytes))
        # Image.open is lazy; force the full decode here so a truncated or
        # corrupt upload is reported as a client error (400) instead of
        # surfacing later inside inference as a 500.
        img.load()
    except Exception as e:
        raise HTTPException(400, f"could not parse image: {e}") from e
    try:
        return JSONResponse(classify_image(img))
    except Exception as e:
        raise HTTPException(500, f"inference error: {e}") from e
|
|
|
|
@app.post("/classify_b64")
async def classify_b64(payload: ClassifyJSON):
    """Classify an image sent as JSON ``{"image": "<base64-encoded image>"}``.

    The value may be bare base64 or a ``data:<mime>;base64,<data>`` URI.

    Returns:
        The result of ``classify_image`` as a JSON response.

    Raises:
        HTTPException: 400 when the payload cannot be decoded as an image;
            500 when inference fails.
    """
    try:
        encoded = payload.image
        if encoded.startswith("data:"):
            # Drop the "data:<mime>;base64," header.  Valid base64 never
            # contains ":" or ",", so bare payloads are unaffected.
            encoded = encoded.partition(",")[2]
        img_bytes = base64.b64decode(encoded)
        img = Image.open(io.BytesIO(img_bytes))
        # Image.open is lazy; force the full decode here so a truncated or
        # corrupt image is reported as a client error (400) instead of
        # surfacing later inside inference as a 500.
        img.load()
    except Exception as e:
        raise HTTPException(400, f"could not decode image: {e}") from e
    try:
        return JSONResponse(classify_image(img))
    except Exception as e:
        raise HTTPException(500, f"inference error: {e}") from e
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve a minimal single-page upload form that posts to ``/classify``.

    The page script renders every server-supplied value (prediction, raw
    model output, error detail) through ``textContent`` rather than
    ``innerHTML``, so that text is never parsed as HTML.
    """
    # NOTE(review): the Korean UI strings and the heading emoji were restored
    # from mis-decoded UTF-8; the emoji in particular is a best guess.
    return """
<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="utf-8">
<title>Korean Pest Detector</title>
<style>
body { font-family: -apple-system, system-ui, sans-serif; max-width: 640px; margin: 2rem auto; padding: 0 1rem; }
h1 { font-size: 1.4rem; }
.drop { border: 2px dashed #aaa; border-radius: 12px; padding: 2rem; text-align: center; cursor: pointer; }
.drop:hover { background: #f5f5f5; }
pre { background: #f5f5f5; padding: 1rem; border-radius: 8px; overflow-x: auto; }
img { max-width: 100%; border-radius: 8px; margin-top: 1rem; }
.pred { font-size: 1.6rem; font-weight: bold; color: #2a6b3a; }
.err { color: #b00; }
</style>
</head>
<body>
<h1>🌾 Korean Pest Detector</h1>
<p>Qwen3.5-9B + LoRA (Unsloth + PEFT runtime). 19개 클래스, 한국어 출력.</p>
<input id="f" type="file" accept="image/*">
<div id="result"></div>
<script>
document.getElementById('f').onchange = async (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const r = document.getElementById('result');
  const para = (text, cls) => {
    const p = document.createElement('p');
    if (cls) p.className = cls;
    p.textContent = text;
    return p;
  };
  r.replaceChildren(para('분석 중...'));
  const fd = new FormData();
  fd.append('file', file);
  const t0 = performance.now();
  try {
    const resp = await fetch('/classify', {method: 'POST', body: fd});
    const j = await resp.json();
    if (!resp.ok) throw new Error(j.detail || 'error');
    const elapsed = ((performance.now() - t0) / 1000).toFixed(2);
    const info = para('raw: ');
    const code = document.createElement('code');
    code.textContent = j.raw;
    info.append(code, ` · 추론 ${j.elapsed_s}s · 총 ${elapsed}s`);
    const img = document.createElement('img');
    img.src = URL.createObjectURL(file);
    r.replaceChildren(para(j.pred, 'pred'), info, img);
  } catch (err) {
    r.replaceChildren(para(err.message, 'err'));
  }
};
</script>
</body>
</html>
"""
|
|
|
|
if __name__ == "__main__":
    # Script entry point: listen on all interfaces; the PORT environment
    # variable overrides the default port of 8080.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8080")),
        log_level="info",
    )
|
|