File size: 3,871 Bytes
df6b3ac
26dae50
df6b3ac
 
 
 
26dae50
 
 
 
df6b3ac
 
 
26dae50
 
 
 
df6b3ac
 
 
 
26dae50
df6b3ac
 
 
 
 
 
 
 
 
26dae50
df6b3ac
 
 
 
 
26dae50
 
 
 
 
 
 
 
 
 
df6b3ac
 
 
 
 
 
 
26dae50
 
 
 
 
 
 
 
df6b3ac
26dae50
df6b3ac
26dae50
 
 
 
 
 
 
df6b3ac
 
26dae50
df6b3ac
 
 
 
26dae50
 
 
df6b3ac
26dae50
 
 
 
 
df6b3ac
 
 
 
 
 
 
 
 
26dae50
df6b3ac
 
26dae50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Vision-language: describe / answer about an image, in PT/EN.

Generic loader (AutoModelForImageTextToText) — supports:
  - Qwen/Qwen3-VL-2B-Instruct (default — light, fast, strong OCR)
  - Qwen/Qwen2.5-VL-3B-Instruct, openbmb/MiniCPM-V-4.6, etc.
Swappable via IRIS_VLM_MODEL. The VLM IS the text generator for speech.
"""
import os

# Process-wide cache, filled on first use by _load(): the loaded model and the
# processor that accompanies it. Both stay None until then.
_model = None
_aux = None
# Model identifier handed to from_pretrained(); taken from the IRIS_VLM_MODEL
# environment variable when set, otherwise the Qwen3-VL default below.
MODEL_ID = os.environ.get("IRIS_VLM_MODEL", "Qwen/Qwen3-VL-2B-Instruct")
# Value forwarded as `downsample_mode` in describe(); read from IRIS_DOWNSAMPLE.
DOWNSAMPLE = os.environ.get("IRIS_DOWNSAMPLE", "4x")  # MiniCPM only (detail for OCR)

# System prompt used for every language other than "en" (see _prompt).
# Brazilian Portuguese: answer in at most two short sentences, read important
# text verbatim, total up any money, and read amount + due date on bills.
SYSTEM_PT = (
    "Você é os olhos de uma pessoa cega. RESPONDA OBRIGATORIAMENTE EM PORTUGUÊS "
    "DO BRASIL, em no máximo duas frases curtas, dizendo só o que é relevante e "
    "útil sobre a cena. Não comece com 'a imagem mostra'. "
    "Se houver texto importante (rótulo, placa, remédio), leia-o exatamente como está. "
    "Se houver DINHEIRO (cédulas ou moedas de real), identifique cada valor e diga o TOTAL. "
    "Se for uma CONTA, boleto ou documento, leia o VALOR TOTAL e a DATA DE VENCIMENTO."
)
# English counterpart of SYSTEM_PT, selected by _prompt when lang == "en".
SYSTEM_EN = (
    "You are the eyes of a blind person. ALWAYS REPLY IN ENGLISH, in at most two "
    "short sentences, saying only what is relevant and useful about the scene. Do "
    "not start with 'the image shows'. "
    "If there is important text (label, sign, medicine), read it exactly as written. "
    "If there is MONEY (banknotes or coins), identify each value and state the TOTAL. "
    "If it is a BILL or document, read the TOTAL AMOUNT and the DUE DATE."
)


def _prompt(lang):
    """Pick the (system prompt, fallback question) pair for *lang*.

    Only the exact value "en" selects the English pair; any other value
    yields the Brazilian Portuguese pair.
    """
    wants_english = lang == "en"
    system_prompt = SYSTEM_EN if wants_english else SYSTEM_PT
    fallback_question = (
        "What is in front of me?" if wants_english else "O que há à minha frente?"
    )
    return system_prompt, fallback_question


def _family() -> str:
    """Classify MODEL_ID as "minicpm" or "qwen".

    The check is a case-insensitive substring match on "minicpm"; every
    model id that does not contain it is treated as "qwen".
    """
    if "minicpm" in MODEL_ID.lower():
        return "minicpm"
    return "qwen"


def _load():
    """Return the cached ``(model, processor)`` pair, loading both on first use.

    The heavy imports (torch, transformers) happen inside the function so that
    importing this module stays cheap. The model is loaded in float16 onto
    ``cuda:0`` and put in eval mode; MiniCPM checkpoints are loaded with
    ``trust_remote_code=True``.

    Returns:
        Tuple ``(model, processor)`` — the same objects on every call once
        loading has succeeded.

    Raises:
        Whatever ``from_pretrained`` raises. In that case nothing is cached,
        so the next call retries the whole load instead of returning a
        half-initialised ``(model, None)`` pair.
    """
    global _model, _aux
    if _model is None:
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor

        kw = {"trust_remote_code": True} if _family() == "minicpm" else {}
        # Load the processor first: if it fails we have not yet spent time
        # and GPU memory on the model weights.
        processor = AutoProcessor.from_pretrained(MODEL_ID, **kw)
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, device_map="cuda:0",
            low_cpu_mem_usage=True, **kw,
        ).eval()
        # Publish both globals together, only after both loads succeeded, so
        # the `_model is None` check can never see a model without a processor.
        _model, _aux = model, processor
    return _model, _aux


def _to_pil(image):
    """Normalise *image* to an RGB ``PIL.Image`` capped at 1024x1024.

    Args:
        image: A filesystem path (``str`` or any ``os.PathLike`` such as
            ``pathlib.Path``), a ``PIL.Image.Image``, or anything accepted by
            ``PIL.Image.fromarray`` (e.g. a numpy frame from the webcam).

    Returns:
        An RGB image, shrunk in place by ``thumbnail`` so that neither side
        exceeds 1024 px. The image returned is the result of ``convert``,
        not the object passed in.
    """
    from PIL import Image

    if isinstance(image, (str, os.PathLike)):
        # convert() forces the pixel data to be read, so the file can be
        # closed as soon as the RGB copy exists.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    else:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)  # numpy frame from the webcam
        image = image.convert("RGB")
    image.thumbnail((1024, 1024))  # fewer vision tokens -> faster; still good OCR
    return image


from .gpu import gpu


@gpu(duration=60)
def describe(image, question: str = "", lang: str = "pt", system: str = None) -> str:
    """Generate a short text answer about *image* with the vision-language model.

    Args:
        image: Anything ``_to_pil`` accepts.
        question: User question; when empty, whitespace-only or None, the
            default question for *lang* is used.
        lang: Language selector passed to ``_prompt``.
        system: Optional system prompt; when truthy it replaces the default
            system prompt for *lang*.

    Returns:
        The decoded text of the newly generated tokens for the first (only)
        sequence, with special tokens skipped and outer whitespace stripped.
    """
    import torch

    picture = _to_pil(image)
    default_system, default_question = _prompt(lang)
    system_text = system if system else default_system
    user_text = (question or "").strip() or default_question
    vlm, processor = _load()

    conversation = [
        {"role": "system", "content": system_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": picture},
                {"type": "text", "text": user_text},
            ],
        },
    ]

    # Only the MiniCPM family receives the extra down-sampling arguments.
    is_minicpm = _family() == "minicpm"
    template_extra = (
        {"downsample_mode": DOWNSAMPLE, "max_slice_nums": 36} if is_minicpm else {}
    )
    generate_extra = {"downsample_mode": DOWNSAMPLE} if is_minicpm else {}

    batch = processor.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        **template_extra,
    )
    batch = batch.to(vlm.device)

    with torch.no_grad():
        output_ids = vlm.generate(
            **batch, max_new_tokens=96, do_sample=False, **generate_extra
        )

    # generate() output starts with the prompt ids; slice them off so that
    # only the newly produced tokens are decoded.
    completions = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(batch.input_ids, output_ids)
    ]
    decoded = processor.batch_decode(completions, skip_special_tokens=True)
    return decoded[0].strip()