| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""Run Reka Edge 7B on an image or video + text query.

Usage:
    uv run example.py --image photo.jpg
    uv run example.py --image photo.jpg --prompt "What is this?"
    uv run example.py --video media/dashcam.mp4 --prompt "Is this person falling asleep?"
    uv run example.py --image photo.jpg --model /path/to/local/checkpoint
"""
| |
|
| | import argparse |
| |
|
| | import torch |
| | from transformers import AutoModelForImageTextToText, AutoProcessor |
| |
|
| |
|
def select_device() -> tuple[torch.device, torch.dtype]:
    """Pick the best available device and a matching model dtype.

    Preference order: CUDA (fp16) > MPS (fp16) > CPU (fp32).

    The MPS backend is smoke-tested with a tiny allocation because a build
    can report itself available yet still fail at runtime; if the probe
    fails we fall back to CPU instead of crashing.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16

    mps_backend = getattr(torch.backends, "mps", None)
    if (
        mps_backend is not None
        and torch.backends.mps.is_built()
        and torch.backends.mps.is_available()
    ):
        try:
            torch.zeros(1, device="mps")  # probe: confirm MPS actually works
        except Exception:
            pass  # broken MPS runtime -- fall through to CPU
        else:
            return torch.device("mps"), torch.float16

    return torch.device("cpu"), torch.float32


def build_messages(args: argparse.Namespace) -> list[dict]:
    """Build the single-turn chat message list for the processor.

    The media entry is the video when ``--video`` was given, otherwise the
    image; the text prompt always follows the media entry.
    """
    if args.video:
        media_entry = {"type": "video", "video": args.video}
    else:
        media_entry = {"type": "image", "image": args.image}
    return [
        {
            "role": "user",
            "content": [
                media_entry,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]


def main() -> None:
    """Parse CLI args, run Reka Edge 7B on the given media + prompt, print the reply."""
    parser = argparse.ArgumentParser(description="Reka Edge 7B inference")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image", help="Path to an image file")
    group.add_argument("--video", help="Path to a video file")
    parser.add_argument(
        "--prompt",
        default="Describe what you see in detail.",
        help="Text prompt (default: 'Describe what you see in detail.')",
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model ID or local path (default: current directory)",
    )
    args = parser.parse_args()

    device, model_dtype = select_device()

    # trust_remote_code: the checkpoint ships its own modeling/processing code.
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=model_dtype,
    ).eval()
    model = model.to(device)

    messages = build_messages(args)

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move tensors to the target device; cast only floating-point tensors
    # (e.g. pixel values) to the model dtype -- token ids must stay integral.
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            if val.is_floating_point():
                inputs[key] = val.to(device=device, dtype=model_dtype)
            else:
                inputs[key] = val.to(device=device)

    with torch.inference_mode():
        # Reka emits "<sep>" as an end-of-turn marker; stop generation on
        # either it or the regular EOS token.
        sep_token_id = processor.tokenizer.convert_tokens_to_ids("<sep>")
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=[processor.tokenizer.eos_token_id, sep_token_id],
        )

    # Decode only the newly generated tokens (drop the echoed prompt).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0, input_len:]
    output_text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    # A literal "<sep>" can survive decoding when it is not registered as a
    # special token; strip it defensively.
    output_text = output_text.replace("<sep>", "").strip()
    print(output_text)
| |
|
| |
|
# Script entry point: run CLI inference only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
| |
|