File size: 8,080 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import argparse
import logging
import os
from typing import Optional

import torch
from PIL import Image
import requests
from io import BytesIO
import yaml

from transformers import GenerationConfig, AutoProcessor

# Allow running without installation when working inside the repo
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from larm.memory_generator.memgen_model import LatentMemoryModel


def _load_image(image_path_or_url: str) -> Image.Image:
    """Load an image from a local path or an HTTP(S) URL and return it as RGB.

    Args:
        image_path_or_url: Local filesystem path, or a URL starting with
            ``http://`` or ``https://``.

    Returns:
        An in-memory ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        FileNotFoundError: If a local path does not exist.
        requests.RequestException: If fetching the URL fails or the server
            returns an error status (via ``raise_for_status``).
    """
    # Match the URL scheme explicitly: a bare ``startswith("http")`` would also
    # treat local files such as ``http_cat.png`` as URLs.
    if image_path_or_url.startswith(("http://", "https://")):
        resp = requests.get(image_path_or_url, timeout=30)
        resp.raise_for_status()
        with Image.open(BytesIO(resp.content)) as img:
            return img.convert("RGB")
    if not os.path.exists(image_path_or_url):
        raise FileNotFoundError(f"Image not found: {image_path_or_url}")
    # ``Image.open`` is lazy; the context manager releases the file handle once
    # ``convert`` has produced an independent in-memory image.
    with Image.open(image_path_or_url) as img:
        return img.convert("RGB")


def build_inputs(
    processor,
    messages,
    image: Optional[Image.Image] = None,
    pixel_dtype: Optional[torch.dtype] = torch.bfloat16,
):
    """Encode a chat conversation (and optional image) into model inputs.

    Args:
        processor: Multimodal processor exposing ``apply_chat_template`` and
            ``__call__(text=..., images=..., return_tensors=..., padding=...)``
            returning a mapping of tensors.
        messages: Chat messages in the format expected by
            ``processor.apply_chat_template``.
        image: Optional image to encode alongside the text.
        pixel_dtype: dtype that ``pixel_values`` are cast to. Defaults to
            ``torch.bfloat16`` (the previously hard-coded value); pass ``None``
            to keep whatever dtype the processor produced.

    Returns:
        Tuple ``(input_ids, attention_mask, pixel_values, image_grid_thw)``.
        The last two are ``None`` when the processor output lacks them.
    """
    # Apply the chat template first so that the processor knows where to
    # insert the image tokens.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if image is not None:
        enc = processor(text=[text], images=[image], return_tensors="pt", padding=False)
    else:
        enc = processor(text=[text], return_tensors="pt", padding=False)

    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    pixel_values = enc.get("pixel_values")
    image_grid_thw = enc.get("image_grid_thw")
    if pixel_values is not None and pixel_dtype is not None:
        pixel_values = pixel_values.to(pixel_dtype)
    return input_ids, attention_mask, pixel_values, image_grid_thw


def load_model_from_cfg(cfg_path: str, device: torch.device):
    """Instantiate a ``LatentMemoryModel`` from a YAML config file.

    The file may contain the model settings directly, or a training-style
    config that wraps them under a top-level ``model`` key.

    Args:
        cfg_path: Path to the YAML config.
        device: Device the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If the file does not contain a YAML mapping.
    """
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # ``yaml.safe_load`` yields None for an empty file and a list/scalar for a
    # non-mapping document; fail with a clear message instead of a TypeError.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config {cfg_path!r} must contain a YAML mapping, got {type(cfg).__name__}"
        )
    model_cfg = cfg.get("model", cfg)  # support wrapped config
    model = LatentMemoryModel.from_config(model_cfg).to(device)
    model.eval()
    return model


def _cast_override_value(value: str):
    """Convert an ``--options`` value string to None / bool / int / float, else keep the str."""
    lowered = value.lower()
    if lowered == "null":
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value


def _set_nested(cfg_dict: dict, key_path: str, value: str) -> None:
    """Set the dotted ``key_path`` in ``cfg_dict``, creating or replacing non-dict parent levels."""
    *parents, leaf = key_path.split(".")
    cur = cfg_dict
    for key in parents:
        if not isinstance(cur.get(key), dict):
            cur[key] = {}
        cur = cur[key]
    cur[leaf] = _cast_override_value(value)


def _load_model_with_overrides(cfg_path: str, overrides: list, device: torch.device):
    """Load the model from ``cfg_path`` after applying flat ``[KEY, VALUE, ...]`` overrides."""
    import tempfile

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(f"Config {cfg_path!r} must contain a YAML mapping to apply --options")
    for key, value in zip(overrides[::2], overrides[1::2]):
        _set_nested(cfg, key, value)

    # Round-trip through a temporary YAML file so the patched config is loaded
    # through exactly the same path (load_model_from_cfg) as an unpatched one.
    with tempfile.NamedTemporaryFile("w", encoding="utf-8", delete=False, suffix=".yaml") as tmp:
        yaml.safe_dump(cfg, tmp)
        tmp_path = tmp.name
    try:
        return load_model_from_cfg(tmp_path, device)
    finally:
        # Remove the temp file even when model construction raises.
        os.remove(tmp_path)


def _print_generation(processor, full_ids: torch.Tensor, prompt_len: int, dump_tokens: bool = False) -> None:
    """Print the clean completion, the raw tail after the last <|image_pad|>, and optionally a token dump."""
    tokenizer = processor.tokenizer

    # 5.a Only the assistant completion (clean). Negative ids are not valid
    # vocabulary entries, so they are dropped before decoding.
    gen_only_valid = [tid for tid in full_ids[prompt_len:].tolist() if tid >= 0]
    clean_text = tokenizer.decode(gen_only_valid, skip_special_tokens=True)
    print("\n===== ASSISTANT (clean) =====\n")
    print(clean_text)

    # 5.b Only after the last <|image_pad|> token (whole sequence if none).
    image_pad_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
    ids_list = full_ids.tolist()
    try:
        start_after_skip = len(ids_list) - ids_list[::-1].index(image_pad_id)
    except ValueError:
        start_after_skip = 0
    # Also remove any residual <|image_pad|> tokens and, as in 5.a, negative
    # ids — passing those to ``decode`` may raise.
    filtered_after_skip = [
        tid for tid in ids_list[start_after_skip:] if tid != image_pad_id and tid >= 0
    ]
    raw_text_with_markers = tokenizer.decode(filtered_after_skip, skip_special_tokens=False)
    print("\n===== RAW (after <|image_pad|>, with special tokens) =====\n")
    print(raw_text_with_markers)

    # 5.c Optional token-by-token dump (ids + tokens).
    if dump_tokens:
        print("\n===== TOKEN DUMP (after <|image_pad|>, skip <|image_pad|>) =====")
        for idx, tid in enumerate(ids_list[start_after_skip:], start=start_after_skip):
            if tid == image_pad_id:
                continue
            tok = tokenizer.decode([tid], skip_special_tokens=False) if tid >= 0 else "<invalid>"
            print(f"[{idx:04d}] id={tid:<8} token={tok}")


def main():
    """Run single-prompt inference with a LatentMemoryModel from the command line.

    Loads the model described by ``--cfg`` (optionally patched with
    ``--options KEY VALUE`` overrides), encodes ``--text`` plus an optional
    ``--image``, generates a completion and prints it in decoded forms.

    Raises:
        ValueError: If ``--options`` does not contain an even number of items.
    """
    parser = argparse.ArgumentParser(description="LatentMemoryModel inference script")
    parser.add_argument("--cfg", required=True, help="Path to YAML config used to instantiate the model")
    parser.add_argument("--image", help="Optional image path or URL")
    parser.add_argument("--text", required=True, help="User prompt text")
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--do_sample", action="store_true", help="Enable sampling")
    parser.add_argument("--options", nargs="*", help="Override model config via KEY VALUE pairs, e.g. --options model.max_prompt_aug_num 0")
    parser.add_argument("--dump_tokens", action="store_true", help="Also print a token-by-token dump of the generated sequence")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    # 1. Load model & processor, applying --options overrides if given.
    overrides = args.options or []
    if len(overrides) % 2 != 0:
        raise ValueError("--options should contain KEY VALUE pairs")
    if overrides:
        model = _load_model_with_overrides(args.cfg, overrides, device)
    else:
        model = load_model_from_cfg(args.cfg, device)
    processor = model.processor  # AutoProcessor loaded inside the model

    # 2. Build the single-turn user message (multimodal content when an image is given).
    image = _load_image(args.image) if args.image else None
    if image is not None:
        content = [
            {"type": "image", "image": image},
            {"type": "text", "text": args.text},
        ]
    else:
        content = args.text
    messages = [{"role": "user", "content": content}]

    # 3. Tokenize / encode and move tensors to the target device.
    input_ids, attention_mask, pixel_values, image_grid_thw = build_inputs(processor, messages, image)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    if pixel_values is not None:
        pixel_values = pixel_values.to(device)
    if image_grid_thw is not None:
        image_grid_thw = image_grid_thw.to(device)

    # 4. Build generation config
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        do_sample=args.do_sample,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
    )

    # 5. Generate
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=gen_cfg,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )
    if isinstance(outputs, tuple):  # when return_augmentation_mask=True
        outputs = outputs[0]
    # Batch size is 1: keep the single sequence (prompt + completion).
    full_ids = outputs[0].detach().cpu()
    _print_generation(processor, full_ids, input_ids.size(1), dump_tokens=args.dump_tokens)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()