# pest-detector-deploy / inference.py — uploaded by pfox1995
# Commit 96df5b9 (verified): Initial deploy-ready bundle: adapter + README + inference.py + requirements.txt
#!/usr/bin/env python3
"""Reference inference script for the Korean pest detector LoRA.

This is the validated deployment recipe — every gotcha we hit during
the export rabbit-hole is encoded here. See README.md for the full list.

Usage:
# Single image
python inference.py path/to/image.jpg
# Glob (--bench compares against parent dir as ground truth)
python inference.py 'val/*/*.jpg' --bench
# 4-bit (8.7 GB VRAM)
python inference.py path/to/image.jpg --4bit
"""
import argparse
import glob
import os
import sys
import time
from pathlib import Path
import torch
from PIL import Image
# ─── Constants from training ─────────────────────────────────────────────
# Class labels. classify() maps the generated text onto one of these
# (exact match first, then the longest class name the output starts with),
# and --bench compares predictions against image parent-directory names.
# NOTE(review): most literals below look mojibake'd (UTF-8 Hangul rendered
# through a Greek Windows code page), while a few (e.g. "담배가루이", "무잎벌",
# "정상") are proper Hangul. Predictions are compared with == / startswith, so
# these must byte-match the model's Korean output — verify against the
# original UTF-8 file before relying on them.
PEST_CLASSES = [
    "κ²€κ±°μ„Έλ―Έλ°€λ‚˜λ°©", "κ½ƒλ…Έλž‘μ΄μ±„λ²Œλ ˆ", "담배가루이", "λ‹΄λ°°κ±°μ„Έλ―Έλ‚˜λ°©",
    "λ‹΄λ°°λ‚˜λ°©", "λ„λ‘‘λ‚˜λ°©", "λ¨Ήλ…Έλ¦°μž¬", "λͺ©ν™”λ°”λ‘‘λͺ…λ‚˜λ°©", "무잎벌",
    "λ°°μΆ”μ’€λ‚˜λ°©", "λ°°μΆ”ν°λ‚˜λΉ„", "벼룩잎벌레", "λΉ„λ‹¨λ…Έλ¦°μž¬", "μ©λ©λ‚˜λ¬΄λ…Έλ¦°μž¬",
    "μ•Œλ½μˆ˜μ—Όλ…Έλ¦°μž¬", "정상", "큰28μ λ°•μ΄λ¬΄λ‹Ήλ²Œλ ˆ", "ν†±λ‹€λ¦¬κ°œλ―Έν—ˆλ¦¬λ…Έλ¦°μž¬",
    "νŒŒλ°€λ‚˜λ°©",
]
# Korean system prompt, grouped with the training-time constants and sent
# verbatim as the system message by classify(); the encoding caveat noted for
# PEST_CLASSES applies here too.
SYSTEM_MSG = (
    "당신은 μž‘λ¬Ό ν•΄μΆ© 식별 μ „λ¬Έκ°€μž…λ‹ˆλ‹€. "
    "사진을 보고 ν•΄μΆ©μ˜ μ΄λ¦„λ§Œ ν•œκ΅­μ–΄λ‘œ λ‹΅ν•˜μ„Έμš”. "
    '해좩이 μ—†μœΌλ©΄ "정상"이라고만 λ‹΅ν•˜μ„Έμš”. '
    "λΆ€κ°€ μ„€λͺ… 없이 μ΄λ¦„λ§Œ 좜λ ₯ν•˜μ„Έμš”."
)
# Korean user-turn text, sent together with the letterboxed image.
USER_PROMPT = "이 사진에 μžˆλŠ” ν•΄μΆ©μ˜ 이름을 μ•Œλ €μ£Όμ„Έμš”."
# Side length (pixels) of the square canvas every input image is fitted onto.
LETTERBOX_SIZE = 512
# RGB colour of the letterbox padding (mid-grey).
LETTERBOX_FILL = (128, 128, 128)
# ─── Image preprocessing (matches training) ──────────────────────────────
def letterbox(
    img: Image.Image,
    size: int = LETTERBOX_SIZE,
    fill: tuple[int, int, int] = LETTERBOX_FILL,
) -> Image.Image:
    """Fit *img* into a ``size`` x ``size`` RGB square without distorting it.

    The image is converted to RGB and scaled with Lanczos resampling so that
    its longer side becomes ``size`` pixels. It is then pasted, centred, onto
    a canvas filled with ``fill``. This is the same preprocessing that was
    used in training.

    Args:
        img: Source image in any PIL mode.
        size: Side length of the square output canvas, in pixels.
        fill: RGB colour of the padding bars (``LETTERBOX_FILL`` by default).

    Returns:
        A new ``size`` x ``size`` RGB image.
    """
    img = img.convert("RGB")
    w, h = img.size
    scale = size / max(w, h)
    # Clamp to >= 1 px: for extreme aspect ratios the short side can round to
    # 0, and Pillow refuses to resize to a zero-sized image.
    nw = max(1, round(w * scale))
    nh = max(1, round(h * scale))
    resized = img.resize((nw, nh), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (size, size), fill)
    # Centre the resized image; an odd leftover pixel ends up right/bottom.
    canvas.paste(resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas
# ─── Model loading (the working setup) ───────────────────────────────────
def load_model(base_id: str, adapter: str, four_bit: bool = False):
    """Load the base model, attach the LoRA adapter, and switch to inference.

    Returns (model, tokenizer) ready for inference.

    CRITICAL: uses unsloth.FastVisionModel + peft.PeftModel.from_pretrained
    runtime hooks. Do NOT call merge_and_unload — it silently corrupts
    the linear_attn LoRA delta in this architecture.

    Args:
        base_id: Model id passed to ``FastVisionModel.from_pretrained``.
        adapter: A local directory containing the LoRA adapter. Anything that
            is not an existing directory is treated as a Hugging Face Hub repo
            id and fetched with ``snapshot_download``.
        four_bit: Forwarded as ``load_in_4bit`` when loading the base model.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped model after
        ``FastVisionModel.for_inference`` and ``eval()``, plus the
        tokenizer/processor returned by ``FastVisionModel.from_pretrained``.
    """
    # Heavy third-party imports are deferred until the model is actually needed.
    # NOTE(review): unsloth is imported before peft here; unsloth generally
    # expects to be imported first so its patches apply — keep this order.
    from unsloth import FastVisionModel
    from peft import PeftModel
    from huggingface_hub import snapshot_download
    print(f"Loading base via FastVisionModel: {base_id} (load_in_4bit={four_bit})", flush=True)
    t0 = time.time()
    model, tokenizer = FastVisionModel.from_pretrained(base_id, load_in_4bit=four_bit)
    # memory_allocated() is in bytes; divided by 1e9 to report decimal GB.
    print(f" loaded in {time.time()-t0:.1f}s; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
    adapter_dir = adapter if os.path.isdir(adapter) else snapshot_download(repo_id=adapter)
    print(f"Attaching LoRA: {adapter_dir}", flush=True)
    # Keep the adapter as a runtime PEFT wrapper (no merge; see docstring).
    model = PeftModel.from_pretrained(model, adapter_dir)
    # Required — flips internal mode. Without it generation drifts to 'adge' attractor.
    FastVisionModel.for_inference(model)
    model.eval()
    print(f" ready; vram={torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
    return model, tokenizer
# ─── Single-image classification ─────────────────────────────────────────
def classify(model, tokenizer, img: Image.Image) -> dict:
    """Predict the pest class shown in a single image.

    Args:
        model: LoRA-wrapped model as returned by ``load_model``.
        tokenizer: Tokenizer/processor as returned by ``load_model``. It is
            used to render the chat template, tokenize image + prompt, and
            decode the output.
        img: Input image in any PIL mode; it is letterboxed before inference.

    Returns:
        A dict with three keys:

        - ``pred``: a name from ``PEST_CLASSES`` when the output equals or
          starts with one; otherwise the raw output unchanged.
        - ``raw``: the decoded, whitespace-stripped generated text.
        - ``elapsed_s``: seconds spent in ``model.generate``.
    """
    image = letterbox(img, LETTERBOX_SIZE)
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": USER_PROMPT},
        ]},
    ]
    # enable_thinking=False as DIRECT kwarg (NOT chat_template_kwargs={...},
    # which is silently ignored by VLM processors).
    text = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, enable_thinking=False,
    )
    inputs = tokenizer(
        image, text, add_special_tokens=False, return_tensors="pt",
    ).to("cuda")
    prompt_len = inputs["input_ids"].shape[1]
    # stop_strings needs generate() to be given a text tokenizer. A VLM
    # processor exposes it as `.tokenizer`; a plain tokenizer is used as-is.
    text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted mid-measurement.
    t0 = time.perf_counter()
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=10,  # NOT 16+ — over-running emits 'adge'
            use_cache=True,
            stop_strings=["\n"],  # natural training-time stop
            tokenizer=text_tokenizer,
        )
    elapsed = time.perf_counter() - t0

    # Decode only the newly generated tokens (skip the echoed prompt).
    raw = tokenizer.decode(
        out[0][prompt_len:], skip_special_tokens=True,
    ).strip()

    # Direct equality typically works. Otherwise fall back to the longest
    # class name the output starts with; checking longest names first means
    # a more specific name wins over a shorter name that is its prefix.
    pred = raw
    if raw not in PEST_CLASSES:
        longest_first = sorted(PEST_CLASSES, key=len, reverse=True)
        pred = next((c for c in longest_first if raw.startswith(c)), raw)
    return {"pred": pred, "raw": raw, "elapsed_s": elapsed}
# ─── CLI ─────────────────────────────────────────────────────────────────
def main(argv=None):
    """Command-line entry point: classify images and optionally score them.

    Positional arguments are image paths or glob patterns. An argument that
    contains any of ``*?[`` is expanded with ``glob`` and sorted. Paths that
    are not regular files, and patterns that match nothing, are reported on
    stderr and skipped. If nothing remains, the script exits with
    "no input files". Images that cannot be opened or decoded are reported
    and skipped.

    With ``--bench``, each file's parent directory name is used as its
    ground-truth label, and overall plus per-class accuracy is printed at
    the end.

    Args:
        argv: Optional argument list. Defaults to ``sys.argv[1:]`` via
            argparse, so the CLI can also be driven programmatically.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("paths", nargs="+", help="Image file(s) or glob(s).")
    ap.add_argument("--base", default="unsloth/Qwen3.5-9B")
    ap.add_argument("--adapter", default="pfox1995/pest-detector-final")
    ap.add_argument("--bench", action="store_true",
                    help="Treat parent dir name as truth; print accuracy.")
    ap.add_argument("--4bit", dest="four_bit", action="store_true",
                    help="Load base in bnb NF4 (~8.7 GB VRAM, no accuracy loss for this task).")
    args = ap.parse_args(argv)

    # Expand globs and keep regular files only. Report anything dropped
    # instead of silently ignoring a typo'd path or an empty pattern.
    files: list[str] = []
    for p in args.paths:
        if any(c in p for c in "*?["):
            matches = [m for m in sorted(glob.glob(p)) if os.path.isfile(m)]
            if not matches:
                print(f"warning: no files match {p!r}", file=sys.stderr)
            files.extend(matches)
        elif os.path.isfile(p):
            files.append(p)
        else:
            print(f"warning: not a file, skipping: {p}", file=sys.stderr)
    if not files:
        sys.exit("no input files")

    model, tokenizer = load_model(args.base, args.adapter, args.four_bit)

    correct = 0
    # truth label -> [number correct, number seen]
    per_class: dict[str, list[int]] = {}
    for f in files:
        # In --bench mode the parent directory name is the ground-truth label.
        truth = Path(f).parent.name if args.bench else None
        try:
            # copy() forces a full decode and detaches the pixels from the file.
            # Corrupt or unreadable images therefore fail here and are skipped,
            # instead of aborting a long run after the model has loaded.
            with Image.open(f) as src:
                img = src.copy()
        except OSError as err:
            print(f"warning: skipping unreadable image {f}: {err}", file=sys.stderr)
            continue
        out = classify(model, tokenizer, img)

        ok = bool(truth) and out["pred"] == truth
        marker = ("✓" if ok else "✗") if truth else " "
        if truth:
            counts = per_class.setdefault(truth, [0, 0])
            counts[0] += int(ok)
            counts[1] += 1
            correct += int(ok)
        label = f" truth={truth}" if truth else ""
        print(f"{marker} pred={out['pred']:<20s} ({out['elapsed_s']:.1f}s)"
              f"{label} [{Path(f).name}]")

    if args.bench and per_class:
        total = sum(t for _, t in per_class.values())
        print(f"\n=== ACCURACY: {correct}/{total} = {100*correct/total:.1f}% ===")
        # Per-class breakdown, best accuracy first (ties keep first-seen order).
        for cls, (c, t) in sorted(per_class.items(), key=lambda x: -x[1][0]/max(1, x[1][1])):
            print(f" {c}/{t} {100*c/t:5.1f}% {cls}")


if __name__ == "__main__":
    main()