File size: 1,844 Bytes
9229b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from prompts import make_user_query, system_prompt

from transformers import (
    Qwen3_5ForConditionalGeneration, 
    AutoProcessor, 
)

from PIL import Image
import torch

# Local checkpoint directory; both the model weights and the processor
# configuration are read from here.
MODEL_PATH = "M:/ai/qwen3.5_mm_trainer/Qwen3.5-4B-Base_k2"

# Device the model is placed on (passed as device_map below).
DEVICE = "cuda"

# Load the weights in bfloat16 and request the SDPA attention implementation.
model = Qwen3_5ForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map=DEVICE,
)

# NOTE(review): min_pixels is presumably the lower bound of the image pixel
# budget used when resizing (256 patches of 32x32 pixels) - confirm against
# the processor documentation.
# NOTE(review): right padding only matters for batched inputs; left padding
# is the usual choice for batched generation - confirm before batching.
processor = AutoProcessor.from_pretrained(
    MODEL_PATH,
    min_pixels=256 * 32 * 32,
    padding_side="right",
)

# Prompt-construction options. prepare_messages() forwards them positionally,
# in this order, to make_user_query(); their exact meaning is defined there.
C_TYPE = "long_thoughts_v2"
USE_NAMES = True
ADD_TAGS = False
ADD_CHAR_LIST = False
ADD_CHARS_TAGS = False
ADD_CHARS_DESCR = False

def prepare_messages(item):
    """Build the chat messages (system turn + user turn) for one item.

    Args:
        item: Passed unchanged to ``make_user_query`` together with the
            module-level prompt flags; its expected schema is defined by
            that helper and is not visible here.

    Returns:
        A two-element list of message dicts. The first is the system turn
        carrying ``system_prompt`` as text; the second is the user turn,
        containing an image placeholder followed by the generated query text.
    """
    query_text = make_user_query(
        item,
        C_TYPE,
        USE_NAMES,
        ADD_TAGS,
        ADD_CHAR_LIST,
        ADD_CHARS_TAGS,
        ADD_CHARS_DESCR,
    )

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    }
    # The image placeholder precedes the text part of the user turn.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": query_text},
        ],
    }
    return [system_turn, user_turn]

# Single-image run: build the prompt, generate a completion, print it.

# Image.open() is lazy and keeps the underlying file open. Opening it in a
# context manager and converting inside the block decodes the pixels into a
# detached RGB copy, so the file handle is released deterministically.
with Image.open('test_image.png') as src_image:
    img = src_image.convert("RGB")
images = [img]

# An empty item is passed; the query text comes from make_user_query().
msgs = prepare_messages({})
texts = [processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)]

# Tokenize the prompt, preprocess the image, and move every returned tensor
# to the device the model was loaded on.
inputs = processor(text=texts, images=images, return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

with torch.no_grad():
    generate_ids = model.generate(**inputs, max_new_tokens=1024)

# Slice off the first prompt_len positions so that only the tokens after the
# prompt are decoded.
prompt_len = inputs["input_ids"].shape[1]
generated_texts = processor.batch_decode(
    generate_ids[:, prompt_len:],
    skip_special_tokens=True,
)

print(generated_texts[0])