File size: 7,430 Bytes
96df5b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
"""Reference inference script for the Korean pest detector LoRA.

This is the validated deployment recipe — every gotcha we hit during
the export rabbit-hole is encoded here. See README.md for the full list.

Usage:
    # Single image
    python inference.py path/to/image.jpg

    # Glob (--bench compares against parent dir as ground truth)
    python inference.py 'val/*/*.jpg' --bench

    # 4-bit (8.7 GB VRAM)
    python inference.py path/to/image.jpg --4bit
"""
import argparse
import glob
import os
import sys
import time
from pathlib import Path

import torch
from PIL import Image


# ─── Constants from training ─────────────────────────────────────────────
PEST_CLASSES = [
    "κ²€κ±°μ„Έλ―Έλ°€λ‚˜λ°©", "κ½ƒλ…Έλž‘μ΄μ±„λ²Œλ ˆ", "담배가루이", "λ‹΄λ°°κ±°μ„Έλ―Έλ‚˜λ°©",
    "λ‹΄λ°°λ‚˜λ°©", "λ„λ‘‘λ‚˜λ°©", "λ¨Ήλ…Έλ¦°μž¬", "λͺ©ν™”λ°”λ‘‘λͺ…λ‚˜λ°©", "무잎벌",
    "λ°°μΆ”μ’€λ‚˜λ°©", "λ°°μΆ”ν°λ‚˜λΉ„", "벼룩잎벌레", "λΉ„λ‹¨λ…Έλ¦°μž¬", "μ©λ©λ‚˜λ¬΄λ…Έλ¦°μž¬",
    "μ•Œλ½μˆ˜μ—Όλ…Έλ¦°μž¬", "정상", "큰28μ λ°•μ΄λ¬΄λ‹Ήλ²Œλ ˆ", "ν†±λ‹€λ¦¬κ°œλ―Έν—ˆλ¦¬λ…Έλ¦°μž¬",
    "νŒŒλ°€λ‚˜λ°©",
]

SYSTEM_MSG = (
    "당신은 μž‘λ¬Ό ν•΄μΆ© 식별 μ „λ¬Έκ°€μž…λ‹ˆλ‹€. "
    "사진을 보고 ν•΄μΆ©μ˜ μ΄λ¦„λ§Œ ν•œκ΅­μ–΄λ‘œ λ‹΅ν•˜μ„Έμš”. "
    '해좩이 μ—†μœΌλ©΄ "정상"이라고만 λ‹΅ν•˜μ„Έμš”. '
    "λΆ€κ°€ μ„€λͺ… 없이 μ΄λ¦„λ§Œ 좜λ ₯ν•˜μ„Έμš”."
)
USER_PROMPT = "이 사진에 μžˆλŠ” ν•΄μΆ©μ˜ 이름을 μ•Œλ €μ£Όμ„Έμš”."

# Default side length, in pixels, of the square canvas built by letterbox().
LETTERBOX_SIZE = 512
# RGB color (mid-gray) that letterbox() uses to fill the canvas around the resized image.
LETTERBOX_FILL = (128, 128, 128)


# ─── Image preprocessing (matches training) ──────────────────────────────
def letterbox(img: Image.Image, size: int = LETTERBOX_SIZE) -> Image.Image:
    """Fit *img* into a ``size`` x ``size`` square, padding the unused area.

    The image is converted to RGB, scaled with LANCZOS resampling so that its
    longer side equals ``size`` (aspect ratio preserved), and pasted centered
    onto a canvas filled with ``LETTERBOX_FILL``.

    Args:
        img: Source image; it is converted to RGB before resizing.
        size: Side length of the square output, in pixels.

    Returns:
        A new RGB image of ``size`` x ``size`` pixels.
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    scale = size / max(width, height)
    # Clamp each side to at least 1 px: for extreme aspect ratios the shorter
    # side would otherwise round to 0, which Image.resize rejects.
    new_w = max(1, round(width * scale))
    new_h = max(1, round(height * scale))
    resized = rgb.resize((new_w, new_h), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (size, size), LETTERBOX_FILL)
    # Integer offsets center the resized image on the canvas.
    canvas.paste(resized, ((size - new_w) // 2, (size - new_h) // 2))
    return canvas


# ─── Model loading (the working setup) ───────────────────────────────────
def load_model(base_id: str, adapter: str, four_bit: bool = False):
    """Load the base model, attach the LoRA adapter, and switch to inference.

    CRITICAL: the adapter is kept attached through the runtime hooks of
    unsloth.FastVisionModel + peft.PeftModel.from_pretrained. Do NOT call
    merge_and_unload — it silently corrupts the linear_attn LoRA delta in
    this architecture.

    Args:
        base_id: Identifier passed to ``FastVisionModel.from_pretrained``.
        adapter: Path of a local adapter directory; anything that is not an
            existing directory is treated as a repo id and fetched with
            ``snapshot_download``.
        four_bit: Forwarded as ``load_in_4bit`` when loading the base model.

    Returns:
        ``(model, tokenizer)`` ready for inference.
    """
    # Heavy third-party imports are deferred until a model is actually loaded.
    from unsloth import FastVisionModel
    from peft import PeftModel
    from huggingface_hub import snapshot_download

    print(f"Loading base via FastVisionModel: {base_id}  (load_in_4bit={four_bit})", flush=True)
    started = time.time()
    model, tokenizer = FastVisionModel.from_pretrained(base_id, load_in_4bit=four_bit)
    load_secs = time.time() - started
    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f"  loaded in {load_secs:.1f}s; vram={vram_gb:.1f} GB", flush=True)

    # Use a local adapter directory as-is; otherwise download a snapshot.
    if os.path.isdir(adapter):
        adapter_dir = adapter
    else:
        adapter_dir = snapshot_download(repo_id=adapter)
    print(f"Attaching LoRA: {adapter_dir}", flush=True)
    model = PeftModel.from_pretrained(model, adapter_dir)

    # Required: flips the internal mode. Without this call generation drifts
    # to the 'adge' attractor.
    FastVisionModel.for_inference(model)
    model.eval()
    vram_gb = torch.cuda.memory_allocated() / 1e9
    print(f"  ready; vram={vram_gb:.1f} GB", flush=True)

    return model, tokenizer


# ─── Single-image classification ─────────────────────────────────────────
def classify(model, tokenizer, img: Image.Image) -> dict:
    """Classify a single image and report the predicted class name.

    The image is letterboxed, wrapped in a system + user chat prompt, and the
    model generates a short answer that is then mapped onto ``PEST_CLASSES``.

    Args:
        model: Model object exposing ``generate``.
        tokenizer: Callable processor that also provides
            ``apply_chat_template`` and ``decode``.
        img: Image to classify.

    Returns:
        A dict with ``"pred"`` (matched class name, or the raw text when no
        class matches), ``"raw"`` (stripped generated text) and
        ``"elapsed_s"`` (wall-clock seconds spent in ``generate``).
    """
    boxed = letterbox(img, LETTERBOX_SIZE)
    system_turn = {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]}
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": boxed},
            {"type": "text", "text": USER_PROMPT},
        ],
    }
    # enable_thinking=False must be a DIRECT kwarg; passing it through
    # chat_template_kwargs={...} is silently ignored by VLM processors.
    prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        enable_thinking=False,
    )
    batch = tokenizer(boxed, prompt, add_special_tokens=False, return_tensors="pt")
    batch = batch.to("cuda")

    # generate() needs the text tokenizer for stop_strings; unwrap it when the
    # object passed in carries one as ``.tokenizer``.
    text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)

    started = time.time()
    with torch.inference_mode():
        generated = model.generate(
            **batch,
            max_new_tokens=10,      # NOT 16+: over-running emits 'adge'
            use_cache=True,
            stop_strings=["\n"],    # natural training-time stop
            tokenizer=text_tokenizer,
        )
    elapsed = time.time() - started

    # Decode only the newly generated tokens, i.e. everything after the prompt.
    prompt_len = batch["input_ids"].shape[1]
    raw = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()

    # An exact match is the normal case; otherwise take the longest class name
    # that prefixes the output, and fall back to the raw text if none does.
    if raw in PEST_CLASSES:
        pred = raw
    else:
        longest_first = sorted(PEST_CLASSES, key=len, reverse=True)
        pred = next((name for name in longest_first if raw.startswith(name)), raw)

    return {"pred": pred, "raw": raw, "elapsed_s": elapsed}


# ─── CLI ─────────────────────────────────────────────────────────────────
def _expand_inputs(paths: list[str]) -> list[str]:
    """Expand glob patterns in *paths* and keep only existing regular files.

    An input that contributes no file (a missing path, a directory, or a
    pattern without matches) is reported on stderr rather than dropped
    silently.
    """
    files: list[str] = []
    for p in paths:
        is_pattern = any(c in p for c in "*?[")
        candidates = sorted(glob.glob(p)) if is_pattern else [p]
        found = [f for f in candidates if os.path.isfile(f)]
        if not found:
            print(f"warning: no file found for input: {p}", file=sys.stderr)
        files.extend(found)
    return files


def main():
    """Command-line entry point: classify images and optionally score them.

    Positional arguments are image paths or glob patterns. With ``--bench``
    the name of each file's parent directory is taken as the ground-truth
    class, and overall plus per-class accuracy is printed after the
    per-file results.

    Raises:
        SystemExit: With the message ``"no input files"`` when none of the
            inputs resolves to an existing file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("paths", nargs="+", help="Image file(s) or glob(s).")
    ap.add_argument("--base", default="unsloth/Qwen3.5-9B")
    ap.add_argument("--adapter", default="pfox1995/pest-detector-final")
    ap.add_argument("--bench", action="store_true",
                    help="Treat parent dir name as truth; print accuracy.")
    ap.add_argument("--4bit", dest="four_bit", action="store_true",
                    help="Load base in bnb NF4 (~8.7 GB VRAM, no accuracy loss for this task).")
    args = ap.parse_args()

    # Resolve the inputs before loading the model so a bad invocation fails fast.
    files = _expand_inputs(args.paths)
    if not files:
        sys.exit("no input files")

    model, tokenizer = load_model(args.base, args.adapter, args.four_bit)

    correct = 0
    # Ground-truth class -> [number correct, number seen].
    per_class: dict[str, list[int]] = {}
    for f in files:
        truth = Path(f).parent.name if args.bench else None
        with Image.open(f) as raw:
            out = classify(model, tokenizer, raw)
        ok = bool(truth) and out["pred"] == truth
        # Check mark / cross when a ground truth exists, blank otherwise.
        marker = ("✓" if ok else "✗") if truth else " "
        if truth:
            per_class.setdefault(truth, [0, 0])
            per_class[truth][0] += int(ok)
            per_class[truth][1] += 1
            correct += int(ok)
        print(f"{marker} pred={out['pred']:<20s} ({out['elapsed_s']:.1f}s)"
              f"{' truth=' + truth if truth else ''}  [{Path(f).name}]")

    if args.bench and per_class:
        total = sum(t for _, t in per_class.values())
        print(f"\n=== ACCURACY: {correct}/{total} = {100*correct/total:.1f}% ===")
        # List classes from highest to lowest accuracy.
        for cls, (c, t) in sorted(per_class.items(), key=lambda x: -x[1][0]/max(1, x[1][1])):
            print(f"  {c}/{t}  {100*c/t:5.1f}%  {cls}")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()