#!/usr/bin/env python3
# /// script
# requires-python = ">=3.12.7"
# dependencies = [
#     "transformers==4.57.3",
#     "torch==2.9.1",
#     "torchvision==0.24.1",
#     "pillow==12.1.1",
#     "tiktoken==0.12.0",
#     "imageio==2.37.2",
#     "einops==0.8.2",
#     "av",
# ]
# ///
"""Run Reka Edge 7B on an image or video + text query.

Usage:
    uv run example.py --image photo.jpg
    uv run example.py --image photo.jpg --prompt "What is this?"
    uv run example.py --video media/dashcam.mp4 --prompt "Is this person falling asleep?"
    uv run example.py --image photo.jpg --model /path/to/local/checkpoint
"""

import argparse

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# End-of-turn marker for Reka chat checkpoints.
# NOTE(review): the original code looked this token up with an empty string
# (a no-op for .replace() and an unk/None lookup for convert_tokens_to_ids) —
# "<sep>" is Reka's documented turn separator; confirm against the
# checkpoint's tokenizer config.
SEP_TOKEN = "<sep>"


def _pick_device() -> tuple[torch.device, torch.dtype]:
    """Return the best available (device, dtype) pair.

    Prefers CUDA, then MPS, then CPU. MPS is verified with a tiny
    allocation because ``is_available()`` can report true on builds
    where ops still fail at runtime; a failing probe now falls back
    to CPU instead of crashing the script. Half precision is used on
    accelerators, float32 on CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16

    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_built() and mps.is_available():
        try:
            torch.zeros(1, device="mps")  # verify MPS actually works
        except RuntimeError:
            pass  # broken MPS build — fall through to CPU
        else:
            return torch.device("mps"), torch.float16

    return torch.device("cpu"), torch.float32


def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments; exactly one of --image/--video is required."""
    parser = argparse.ArgumentParser(description="Reka Edge 7B inference")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image", help="Path to an image file")
    group.add_argument("--video", help="Path to a video file")
    parser.add_argument(
        "--prompt",
        default="Describe what you see in detail.",
        help="Text prompt (default: 'Describe what you see in detail.')",
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model ID or local path (default: current directory)",
    )
    return parser.parse_args()


def main() -> None:
    """Load the model, run one image/video + text generation, print the reply."""
    args = _parse_args()
    device, model_dtype = _pick_device()

    # Load processor and model (trust_remote_code: the checkpoint ships
    # custom modeling code).
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=model_dtype,
    ).eval()
    model = model.to(device)

    # Build a single-turn chat message: one media entry plus the text prompt.
    if args.video:
        media_entry = {"type": "video", "video": args.video}
    else:
        media_entry = {"type": "image", "image": args.image}
    messages = [
        {
            "role": "user",
            "content": [
                media_entry,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]

    # Tokenize using the chat template.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move tensors to the target device; cast only floating-point tensors
    # (pixel values) to the model dtype — token ids and masks stay integral.
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            if val.is_floating_point():
                inputs[key] = val.to(device=device, dtype=model_dtype)
            else:
                inputs[key] = val.to(device=device)

    # Stop on the <sep> end-of-turn marker in addition to the default EOS.
    # Guard against the marker missing from the vocab: convert_tokens_to_ids
    # returns None (or the unk id) for unknown tokens, and passing that to
    # generate() would either raise or stop on unrelated tokens.
    eos_ids = [processor.tokenizer.eos_token_id]
    sep_token_id = processor.tokenizer.convert_tokens_to_ids(SEP_TOKEN)
    if sep_token_id is not None and sep_token_id != processor.tokenizer.unk_token_id:
        eos_ids.append(sep_token_id)

    # Generate (greedy decoding, no gradients).
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=eos_ids,
        )

    # Decode only the newly generated tokens (skip the prompt prefix).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0, input_len:]
    output_text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Strip any trailing turn-boundary marker that survived decoding.
    output_text = output_text.replace(SEP_TOKEN, "").strip()
    print(output_text)


if __name__ == "__main__":
    main()