File size: 3,929 Bytes
7d24555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72d6425
 
7d24555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.12.7"
# dependencies = [
#     "transformers==4.57.3",
#     "torch==2.9.1",
#     "torchvision==0.24.1",
#     "pillow==12.1.1",
#     "tiktoken==0.12.0",
#     "imageio==2.37.2",
#     "einops==0.8.2",
#     "av",
# ]
# ///
"""Run Reka Edge 7B on an image or video + text query.

Usage:
    uv run example.py --image photo.jpg
    uv run example.py --image photo.jpg --prompt "What is this?"
    uv run example.py --video media/dashcam.mp4 --prompt "Is this person falling asleep?"
    uv run example.py --image photo.jpg --model /path/to/local/checkpoint
"""

import argparse

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


def main() -> None:
    """Parse CLI arguments, load Reka Edge 7B, and answer one image/video + text query.

    Exactly one of --image / --video is required. The best available device is
    chosen in order cuda > mps > cpu, with float16 on accelerators and float32
    on CPU. The decoded answer (generated tokens only) is printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Reka Edge 7B inference")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image", help="Path to an image file")
    group.add_argument("--video", help="Path to a video file")
    parser.add_argument(
        "--prompt",
        default="Describe what you see in detail.",
        help="Text prompt (default: 'Describe what you see in detail.')",
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model ID or local path (default: current directory)",
    )
    args = parser.parse_args()

    # Pick best available device: cuda > mps > cpu.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        model_dtype = torch.float16
    else:
        mps_ok = False
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_built() and torch.backends.mps.is_available():
            # FIX: the probe op can itself raise on partially-working MPS
            # setups; previously the exception propagated and crashed the
            # script instead of falling back to CPU.
            try:
                torch.zeros(1, device="mps")  # verify MPS actually works
                mps_ok = True
            except RuntimeError:
                mps_ok = False

        if mps_ok:
            device = torch.device("mps")
            model_dtype = torch.float16
        else:
            device = torch.device("cpu")
            model_dtype = torch.float32

    # Load processor and model (trust_remote_code: Reka ships custom code).
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=model_dtype,
    ).eval()
    model = model.to(device)

    # Prepare an image or video + text query in chat-message form.
    if args.video:
        media_entry = {"type": "video", "video": args.video}
    else:
        media_entry = {"type": "image", "image": args.image}
    messages = [
        {
            "role": "user",
            "content": [
                media_entry,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]

    # Tokenize (and preprocess the media) using the model's chat template.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move tensors to device; float tensors (e.g. pixel values) also get the
    # model dtype so they match the weights.
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            if val.is_floating_point():
                inputs[key] = val.to(device=device, dtype=model_dtype)
            else:
                inputs[key] = val.to(device=device)

    # Generate (greedy decode; inference_mode disables autograd bookkeeping).
    with torch.inference_mode():
        # Stop on <sep> token (end-of-turn) in addition to default EOS.
        # FIX: guard against tokenizers that lack "<sep>" — convert_tokens_to_ids
        # may return None, which would poison the eos_token_id list.
        eos_ids = [processor.tokenizer.eos_token_id]
        sep_token_id = processor.tokenizer.convert_tokens_to_ids("<sep>")
        if sep_token_id is not None:
            eos_ids.append(sep_token_id)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=eos_ids,
        )

    # Decode only the newly generated tokens (drop the echoed prompt).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0, input_len:]
    output_text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Strip any trailing <sep> turn-boundary marker the decode may keep.
    output_text = output_text.replace("<sep>", "").strip()
    print(output_text)


# Script entry point: run inference only when executed directly (e.g. `uv run example.py`).
if __name__ == "__main__":
    main()