#!/usr/bin/env python3
"""Reference inference script for the Korean pest detector LoRA.

This is the validated deployment recipe — every gotcha we hit during the
export rabbit-hole is encoded here. See README.md for the full list.

Usage:
    # Single image
    python inference.py path/to/image.jpg

    # Glob (--bench compares against parent dir as ground truth)
    python inference.py 'val/*/*.jpg' --bench

    # 4-bit (8.7 GB VRAM)
    python inference.py path/to/image.jpg --4bit
"""
import argparse
import glob
import os
import sys
import time
import unicodedata
from pathlib import Path

import torch
from PIL import Image

# ─── Constants from training ─────────────────────────────────────────────
PEST_CLASSES = [
    "검거세미밤나방",
    "꽃노랑총채벌레",
    "담배가루이",
    "담배거세미나방",
    "담배나방",
    "도둑나방",
    "먹노린재",
    "목화바둑명나방",
    "무잎벌",
    "배추좀나방",
    "배추흰나비",
    "벼룩잎벌레",
    "비단노린재",
    "썩덩나무노린재",
    "알락수염노린재",
    "정상",
    "큰28점박이무당벌레",
    "톱다리개미허리노린재",
    "파밤나방",
]

SYSTEM_MSG = (
    "당신은 작물 해충 식별 전문가입니다. "
    "사진을 보고 해충의 이름만 한국어로 답하세요. "
    '해충이 없으면 "정상"이라고만 답하세요. '
    "부가 설명 없이 이름만 출력하세요."
)
USER_PROMPT = "이 사진에 있는 해충의 이름을 알려주세요."

LETTERBOX_SIZE = 512
LETTERBOX_FILL = (128, 128, 128)


def _nfc(text: str) -> str:
    """Return *text* in Unicode NFC form.

    Hangul can be stored either precomposed (NFC) or as decomposed jamo
    (NFD); the two render identically but compare unequal, so every label
    comparison in this script goes through this helper.
    """
    return unicodedata.normalize("NFC", text)


# (NFC form, canonical label) pairs, longest label first so that prefix
# matching always picks the most specific class. Built once at import time
# instead of re-sorting on every classification.
_CLASSES_LONGEST_FIRST = tuple(
    (_nfc(label), label)
    for label in sorted(PEST_CLASSES, key=len, reverse=True)
)


# ─── Image preprocessing (matches training) ──────────────────────────────
def letterbox(img: Image.Image, size: int = LETTERBOX_SIZE) -> Image.Image:
    """Fit *img* into a ``size`` x ``size`` square without distorting it.

    The image is converted to RGB, scaled with LANCZOS resampling so that
    its longer side equals ``size``, and pasted centered onto a canvas
    filled with ``LETTERBOX_FILL``.

    Args:
        img: Source image in any PIL mode.
        size: Edge length of the square output, in pixels.

    Returns:
        A new RGB image of exactly ``size`` x ``size`` pixels.
    """
    img = img.convert("RGB")
    w, h = img.size
    scale = size / max(w, h)
    # Clamp to 1 px: for extreme aspect ratios the short side rounds down to
    # 0, and PIL refuses to resize to a zero-sized dimension.
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    resized = img.resize((nw, nh), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (size, size), LETTERBOX_FILL)
    canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas


# ─── Model loading (the working setup) ───────────────────────────────────
def load_model(base_id: str, adapter: str, four_bit: bool = False):
    """Returns (model, tokenizer) ready for inference.

    CRITICAL: uses unsloth.FastVisionModel + peft.PeftModel.from_pretrained
    runtime hooks. Do NOT call merge_and_unload — it silently corrupts the
    linear_attn LoRA delta in this architecture.

    Args:
        base_id: Identifier of the base model passed to
            ``FastVisionModel.from_pretrained``.
        adapter: Local directory holding the LoRA adapter, or a Hub repo id
            that is fetched with ``snapshot_download`` when no such
            directory exists.
        four_bit: Forwarded as ``load_in_4bit`` when loading the base model.
    """
    # Imported lazily so that `--help` and argument errors do not pay for
    # (or require) the heavy ML stack.
    from unsloth import FastVisionModel
    from peft import PeftModel
    from huggingface_hub import snapshot_download

    print(f"Loading base via FastVisionModel: {base_id} (load_in_4bit={four_bit})", flush=True)
    t0 = time.time()
    model, tokenizer = FastVisionModel.from_pretrained(base_id, load_in_4bit=four_bit)
    print(f" loaded in {time.time()-t0:.1f}s; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)

    if os.path.isdir(adapter):
        adapter_dir = adapter
    else:
        adapter_dir = snapshot_download(repo_id=adapter)
    print(f"Attaching LoRA: {adapter_dir}", flush=True)
    model = PeftModel.from_pretrained(model, adapter_dir)

    # Required — flips internal mode. Without it generation drifts to 'adge' attractor.
    FastVisionModel.for_inference(model)
    model.eval()
    print(f" ready; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
    return model, tokenizer


# ─── Single-image classification ─────────────────────────────────────────
def _match_class(raw: str) -> str:
    """Map decoded model output onto a known class label.

    Returns the canonical entry of ``PEST_CLASSES`` that *raw* equals or
    starts with (longest label wins), comparing in NFC form. When nothing
    matches, *raw* is returned unchanged so the caller can still see what
    the model produced.
    """
    normalized = _nfc(raw)
    for nfc_label, label in _CLASSES_LONGEST_FIRST:
        # An exact match is just the special case of a full-length prefix.
        if normalized.startswith(nfc_label):
            return label
    return raw


def classify(model, tokenizer, img: Image.Image) -> dict:
    """Classify a single image with the loaded model.

    Args:
        model: Model returned by ``load_model``.
        tokenizer: Processor/tokenizer returned by ``load_model``.
        img: Image to classify; it is letterboxed before being fed in.

    Returns:
        A dict with keys ``"pred"`` (matched class label, or the raw text
        when no class matches), ``"raw"`` (the stripped decoded output) and
        ``"elapsed_s"`` (wall-clock seconds spent in ``generate``).
    """
    image = letterbox(img, LETTERBOX_SIZE)
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": USER_PROMPT},
        ]},
    ]
    # enable_thinking=False as DIRECT kwarg (NOT chat_template_kwargs={...},
    # which is silently ignored by VLM processors).
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(
        image,
        text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")

    t0 = time.time()
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=10,  # NOT 16+ — over-running emits 'adge'
            use_cache=True,
            stop_strings=["\n"],  # natural training-time stop
            tokenizer=tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer,
        )
    elapsed = time.time() - t0

    # Decode only the newly generated tokens, skipping the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    # Direct equality typically works; prefix matching adds robustness.
    return {"pred": _match_class(raw), "raw": raw, "elapsed_s": elapsed}


# ─── CLI ─────────────────────────────────────────────────────────────────
def main():
    """Parse CLI arguments, classify every input image and print results.

    With ``--bench`` the name of each file's parent directory is taken as
    the ground-truth label, and overall plus per-class accuracy is printed
    after the last image.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("paths", nargs="+", help="Image file(s) or glob(s).")
    ap.add_argument("--base", default="unsloth/Qwen3.5-9B")
    ap.add_argument("--adapter", default="pfox1995/pest-detector-final")
    ap.add_argument("--bench", action="store_true",
                    help="Treat parent dir name as truth; print accuracy.")
    ap.add_argument("--4bit", dest="four_bit", action="store_true",
                    help="Load base in bnb NF4 (~8.7 GB VRAM, no accuracy loss for this task).")
    args = ap.parse_args()

    files: list[str] = []
    for p in args.paths:
        # Only arguments containing glob metacharacters are expanded; plain
        # paths are taken literally.
        if any(c in p for c in "*?["):
            files.extend(sorted(glob.glob(p)))
        else:
            files.append(p)
    files = [f for f in files if os.path.isfile(f)]
    if not files:
        sys.exit("no input files")

    model, tokenizer = load_model(args.base, args.adapter, args.four_bit)

    correct = 0
    per_class: dict[str, list[int]] = {}  # label -> [n_correct, n_total]
    for f in files:
        # Directory names may be NFD-decomposed depending on where the
        # dataset was created; normalize so they compare equal to the
        # precomposed class labels.
        truth = _nfc(Path(f).parent.name) if args.bench else None
        with Image.open(f) as img:
            out = classify(model, tokenizer, img)

        if truth:
            hit = out["pred"] == truth
            marker = "✓" if hit else "✗"
            tally = per_class.setdefault(truth, [0, 0])
            tally[0] += int(hit)
            tally[1] += 1
            correct += int(hit)
            truth_note = " truth=" + truth
        else:
            marker = " "
            truth_note = ""
        print(f"{marker} pred={out['pred']:<20s} ({out['elapsed_s']:.1f}s)"
              f"{truth_note} [{Path(f).name}]")

    if args.bench and per_class:
        total = sum(t for _, t in per_class.values())
        print(f"\n=== ACCURACY: {correct}/{total} = {100*correct/total:.1f}% ===")
        # Best-performing classes first.
        for cls, (c, t) in sorted(per_class.items(),
                                  key=lambda x: -x[1][0]/max(1, x[1][1])):
            print(f" {c}/{t} {100*c/t:5.1f}% {cls}")


if __name__ == "__main__":
    main()