#!/usr/bin/env python3 """FastAPI server for the Korean pest detector. Wraps the validated Unsloth FastVisionModel + PEFT runtime LoRA setup (load_in_4bit=True by default → ~8.7 GB VRAM). Endpoints: GET /health → {"status": "ok", "model_loaded": bool} GET /classes → ["검거세미밤나방", ...] (19 classes) GET / → minimal HTML upload form POST /classify → multipart file OR JSON {"image": ""} returns {"pred": ..., "raw": ..., "elapsed_s": ..., "all_classes": [...]} Env: BASE_MODEL default: unsloth/Qwen3.5-9B ADAPTER default: pfox1995/pest-detector-deploy LOAD_IN_4BIT "true"/"false" (default: true) PORT default: 8080 Usage: python server.py """ import base64 import io import os import time from contextlib import asynccontextmanager from typing import Optional import torch import uvicorn from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import HTMLResponse, JSONResponse from PIL import Image from pydantic import BaseModel # ─── Constants from training (DO NOT change) ───────────────────────────── PEST_CLASSES = [ "검거세미밤나방", "꽃노랑총채벌레", "담배가루이", "담배거세미나방", "담배나방", "도둑나방", "먹노린재", "목화바둑명나방", "무잎벌", "배추좀나방", "배추흰나비", "벼룩잎벌레", "비단노린재", "썩덩나무노린재", "알락수염노린재", "정상", "큰28점박이무당벌레", "톱다리개미허리노린재", "파밤나방", ] SYSTEM_MSG = ( "당신은 작물 해충 식별 전문가입니다. " "사진을 보고 해충의 이름만 한국어로 답하세요. " '해충이 없으면 "정상"이라고만 답하세요. ' "부가 설명 없이 이름만 출력하세요." ) USER_PROMPT = "이 사진에 있는 해충의 이름을 알려주세요." 
LETTERBOX_SIZE = 512              # square canvas side used at training time
LETTERBOX_FILL = (128, 128, 128)  # neutral gray padding


def letterbox(img: Image.Image, size: int = LETTERBOX_SIZE) -> Image.Image:
    """Resize *img* so its longest edge equals *size*, centered on a gray canvas.

    Matches the preprocessing used during fine-tuning: aspect ratio is
    preserved; the remainder of the size×size canvas is LETTERBOX_FILL.
    """
    img = img.convert("RGB")
    w, h = img.size
    scale = size / max(w, h)
    nw, nh = int(round(w * scale)), int(round(h * scale))
    resized = img.resize((nw, nh), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (size, size), LETTERBOX_FILL)
    canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas


# ─── Model state ─────────────────────────────────────────────────────────
class ModelState:
    """Process-wide holder for the loaded model, populated by load_model()."""
    model = None
    tokenizer = None
    text_tokenizer = None  # underlying transformers tokenizer (for stop_strings=)


STATE = ModelState()


def load_model():
    """Load the base vision model + LoRA adapter into STATE (startup only).

    Heavy imports (unsloth/peft/huggingface_hub) are deferred to here so the
    module can be imported without them. Reads BASE_MODEL / ADAPTER /
    LOAD_IN_4BIT / HF_TOKEN from the environment.
    """
    from unsloth import FastVisionModel
    from peft import PeftModel
    from huggingface_hub import snapshot_download

    base = os.environ.get("BASE_MODEL", "unsloth/Qwen3.5-9B")
    adapter = os.environ.get("ADAPTER", "pfox1995/pest-detector-deploy")
    four_bit = os.environ.get("LOAD_IN_4BIT", "true").lower() == "true"

    if os.environ.get("HF_TOKEN"):
        from huggingface_hub import login
        login(token=os.environ["HF_TOKEN"], add_to_git_credential=False)

    print(f"[startup] FastVisionModel.from_pretrained({base}, load_in_4bit={four_bit})", flush=True)
    t0 = time.time()
    model, tok = FastVisionModel.from_pretrained(base, load_in_4bit=four_bit)
    print(f"[startup] loaded base in {time.time()-t0:.1f}s; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)

    # ADAPTER may be a local directory or a Hub repo id.
    adapter_dir = adapter if os.path.isdir(adapter) else snapshot_download(repo_id=adapter)
    print(f"[startup] attaching LoRA: {adapter_dir}", flush=True)
    model = PeftModel.from_pretrained(model, adapter_dir)
    FastVisionModel.for_inference(model)
    model.eval()
    print(f"[startup] ready; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)

    STATE.model = model
    STATE.tokenizer = tok
    # Vision processors wrap the text tokenizer; generate(stop_strings=...)
    # needs the plain text tokenizer.
    STATE.text_tokenizer = tok.tokenizer if hasattr(tok, "tokenizer") else tok


def classify_image(img: Image.Image) -> dict:
    """Run one inference pass and map the raw generation onto PEST_CLASSES.

    Returns {"pred", "raw", "elapsed_s", "all_classes"}; "pred" falls back to
    the raw text when no class matches (kept visible as a debugging signal).
    Raises RuntimeError if load_model() has not run.
    """
    if STATE.model is None:
        raise RuntimeError("Model not loaded")
    image = letterbox(img)
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": USER_PROMPT},
        ]},
    ]
    text = STATE.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # suppress chain-of-thought; we want the label only
    )
    inputs = STATE.tokenizer(
        image,
        text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")

    t0 = time.time()
    with torch.inference_mode():
        out = STATE.model.generate(
            **inputs,
            max_new_tokens=10,       # class names are short; cap generation
            use_cache=True,
            stop_strings=["\n"],     # labels are single-line
            tokenizer=STATE.text_tokenizer,
        )
    elapsed = time.time() - t0

    # Decode only the newly generated tokens (skip the prompt portion).
    raw = STATE.tokenizer.decode(
        out[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    pred = raw if raw in PEST_CLASSES else None
    if pred is None:
        # Longest-prefix match so e.g. "담배나방" cannot shadow "담배거세미나방".
        for c in sorted(PEST_CLASSES, key=len, reverse=True):
            if raw.startswith(c):
                pred = c
                break
    if pred is None:
        pred = raw  # surface raw text if no class match (debugging signal)
    # "all_classes" was promised by the module docstring but missing from the
    # response — added (backward-compatible key addition).
    return {
        "pred": pred,
        "raw": raw,
        "elapsed_s": round(elapsed, 3),
        "all_classes": PEST_CLASSES,
    }


# ─── FastAPI app ─────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving; nothing to clean up on shutdown."""
    load_model()
    yield


app = FastAPI(
    title="Korean Pest Detector",
    description="Qwen3.5-9B + LoRA via Unsloth + PEFT runtime",
    lifespan=lifespan,
)


@app.get("/health")
def health():
    """Liveness probe; model_loaded flips to True after lifespan startup."""
    return {"status": "ok", "model_loaded": STATE.model is not None}


@app.get("/classes")
def classes():
    """List the 19 supported class labels."""
    return {"classes": PEST_CLASSES, "count": len(PEST_CLASSES)}


class ClassifyJSON(BaseModel):
    image: str  # base64-encoded image bytes


@app.post("/classify")
async def classify(
    file: Optional[UploadFile] = File(None),
):
    """Accepts multipart 'file' upload."""
    if file is None:
        raise HTTPException(400, "Provide 'file' multipart field, or POST JSON to /classify_b64")
    try:
        img_bytes = await file.read()
        img = Image.open(io.BytesIO(img_bytes))
        # BUGFIX: Image.open is lazy — force the decode now so a corrupt
        # upload yields a 400 here instead of a 500 inside classify_image.
        img.load()
    except Exception as e:
        raise HTTPException(400, f"could not parse image: {e}") from e
    try:
        return JSONResponse(classify_image(img))
    except Exception as e:
        raise HTTPException(500, f"inference error: {e}") from e


@app.post("/classify_b64")
async def classify_b64(payload: ClassifyJSON):
    """Accepts JSON {"image": ""}."""
    try:
        img_bytes = base64.b64decode(payload.image)
        img = Image.open(io.BytesIO(img_bytes))
        # BUGFIX: force the decode now (see /classify) so bad payloads → 400.
        img.load()
    except Exception as e:
        raise HTTPException(400, f"could not decode image: {e}") from e
    try:
        return JSONResponse(classify_image(img))
    except Exception as e:
        raise HTTPException(500, f"inference error: {e}") from e


@app.get("/", response_class=HTMLResponse)
def index():
    """Minimal HTML upload form (the page the module docstring promises).

    NOTE(review): the original markup was lost in transit (the string held
    only the page text, no tags/form) — reconstructed as a minimal form
    posting to /classify; confirm against the deployed page if available.
    """
    return """<!doctype html>
<html lang="ko">
<head><meta charset="utf-8"><title>Korean Pest Detector</title></head>
<body>
  <h1>🌾 Korean Pest Detector</h1>
  <p>Qwen3.5-9B + LoRA (Unsloth + PEFT runtime). 19개 클래스, 한국어 출력.</p>
  <form action="/classify" method="post" enctype="multipart/form-data">
    <input type="file" name="file" accept="image/*" required>
    <button type="submit">분류</button>
  </form>
</body>
</html>
"""


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")