# NOTE(review): the four lines below are Hugging Face web-page residue
# (file name, avatar caption, commit message) accidentally captured with the
# file; they are commented out so the script remains valid Python.
# reka-edge-2603 / example.py
# donovanOng92's picture
# Update example.py
# 72d6425 verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.12.7"
# dependencies = [
# "transformers==4.57.3",
# "torch==2.9.1",
# "torchvision==0.24.1",
# "pillow==12.1.1",
# "tiktoken==0.12.0",
# "imageio==2.37.2",
# "einops==0.8.2",
# "av",
# ]
# ///
"""Run Reka Edge 7B on an image or video + text query.
Usage:
uv run example.py --image photo.jpg
uv run example.py --image photo.jpg --prompt "What is this?"
uv run example.py --video media/dashcam.mp4 --prompt "Is this person falling asleep?"
uv run example.py --image photo.jpg --model /path/to/local/checkpoint
"""
import argparse
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
def main() -> None:
    """Run Reka Edge 7B on an image or video plus a text prompt and print the answer.

    Parses CLI arguments (exactly one of --image/--video is required), picks the
    best available device (CUDA > MPS > CPU), loads the processor and model,
    builds a single-turn chat message, generates greedily, and prints the
    decoded completion to stdout.
    """
    parser = argparse.ArgumentParser(description="Reka Edge 7B inference")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image", help="Path to an image file")
    group.add_argument("--video", help="Path to a video file")
    parser.add_argument(
        "--prompt",
        default="Describe what you see in detail.",
        help="Text prompt (default: 'Describe what you see in detail.')",
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model ID or local path (default: current directory)",
    )
    args = parser.parse_args()

    # Pick best available device: CUDA first, then MPS, else CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        model_dtype = torch.float16
    else:
        mps_ok = False
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_built() and torch.backends.mps.is_available():
            # BUG FIX: the probe allocation can itself raise on partially
            # working MPS setups. Previously the exception propagated and
            # crashed the script; now a failed probe falls back to CPU,
            # which is what the "verify MPS actually works" comment intended.
            try:
                torch.zeros(1, device="mps")  # verify MPS actually works
                mps_ok = True
            except RuntimeError:
                mps_ok = False
        if mps_ok:
            device = torch.device("mps")
            model_dtype = torch.float16
        else:
            device = torch.device("cpu")
            model_dtype = torch.float32

    # Load processor and model. trust_remote_code is required because the
    # checkpoint ships custom modeling code.
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        # `torch_dtype` is deprecated in the pinned transformers release;
        # `dtype` is the supported spelling and behaves identically.
        dtype=model_dtype,
    ).eval()
    model = model.to(device)

    # Prepare an image or video + text query (mutual exclusion guaranteed above).
    if args.video:
        media_entry = {"type": "video", "video": args.video}
    else:
        media_entry = {"type": "image", "image": args.image}
    messages = [
        {
            "role": "user",
            "content": [
                media_entry,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]

    # Tokenize (and preprocess the media) via the model's chat template.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move tensors to the target device; cast only floating-point tensors
    # (pixel values) to the model dtype — token ids must stay integral.
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            if val.is_floating_point():
                inputs[key] = val.to(device=device, dtype=model_dtype)
            else:
                inputs[key] = val.to(device=device)

    # Generate greedily (do_sample=False) with a bounded output length.
    with torch.inference_mode():
        # Stop on <sep> token (end-of-turn) in addition to default EOS.
        sep_token_id = processor.tokenizer.convert_tokens_to_ids("<sep>")
        # BUG FIX: if "<sep>" is missing from the vocab, convert_tokens_to_ids
        # can return None; passing None inside eos_token_id breaks generate.
        # Filter out None so only real stop ids are used.
        stop_ids = [
            tid
            for tid in (processor.tokenizer.eos_token_id, sep_token_id)
            if tid is not None
        ]
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=stop_ids,
        )

    # Decode only the newly generated tokens (skip the prompt prefix).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0, input_len:]
    output_text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Strip any trailing <sep> turn-boundary marker the decode may leave behind.
    output_text = output_text.replace("<sep>", "").strip()
    print(output_text)
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()