| """ |
| verify_prismatic.py |
| |
| Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate(). |
| """ |
|
|
| import time |
|
|
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoModelForVision2Seq, AutoProcessor |
|
|
| |
# Hub identifier of the exported checkpoint under test, and the sample image every prompt is paired with.
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)

# The prompt template is selected from the checkpoint name: identifiers containing "-prism-" get the bare
# "In: ... / Out:" format, every other identifier gets the system-prompt + USER/ASSISTANT chat format.
if "-prism-" in MODEL_PATH:
    SAMPLE_PROMPTS_FOR_GENERATION = [
        f"In: {question}\nOut:"
        for question in (
            "What is sitting in the coffee?",
            "What's the name of the food on the plate?",
            "caption.",
            "how many beinets..?",
            "Can you give me a lyrical description of the scene",
        )
    ]
else:
    # Only defined on this branch; the two sentences are joined with a single space.
    SYSTEM_PROMPT = " ".join(
        (
            "A chat between a curious user and an artificial intelligence assistant.",
            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        )
    )
    SAMPLE_PROMPTS_FOR_GENERATION = [
        f"{SYSTEM_PROMPT} USER: {question} ASSISTANT:"
        for question in (
            "What is sitting in the coffee?",
            "What's the name of the food on the plate?",
            "caption.",
            "how many beinets..?",
            "Can you give me a lyrical description of the scene",
        )
    ]
|
|
|
|
@torch.inference_mode()
def verify_prismatic() -> None:
    """Load the exported Prismatic VLM through the HF Auto classes and exercise `generate()`.

    Driven entirely by the module-level constants:
      1. Instantiate the processor and the model for `MODEL_PATH` (BF16 weights, Flash-Attention 2 requested),
         placing the model on CUDA when available and on CPU otherwise.
      2. Download `DEFAULT_IMAGE_URL` once and convert it to RGB.
      3. For each prompt in `SAMPLE_PROMPTS_FOR_GENERATION`, run greedy generation five times, print the
         decoded continuation, and accumulate generated-token and wall-clock totals.
      4. Print the aggregate generated-tokens-per-second figure.

    Raises:
        requests.RequestException: If the sample image cannot be fetched (timeout, connection failure, or a
            non-success HTTP status).
    """
    print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    print("[*] Instantiating Processor and Pretrained VLM")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    # NOTE(review): Flash-Attention 2 is requested even when `device` fell back to CPU above; presumably that
    # combination is unsupported -- confirm before relying on the CPU path.
    print("[*] Loading in BF16 with Flash-Attention Enabled")
    vlm = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device)

    # Fetch the sample image with a bounded wait, fail loudly on HTTP errors instead of handing an error page
    # to PIL, and release the streamed connection once the pixels are decoded (`convert` returns a new image).
    download_timeout_s = 30
    with requests.get(DEFAULT_IMAGE_URL, stream=True, timeout=download_timeout_s) as response:
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")

    def _clock() -> float:
        """Return a monotonic timestamp, first draining pending CUDA work so GPU time is not missed."""
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        return time.perf_counter()

    num_tokens, total_time = 0, 0.0

    print("[*] Iterating over Sample Prompts\n===\n")
    for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
        prompt_length = inputs.input_ids.shape[1]

        # Decoding is greedy (`do_sample=False`); the call is repeated five times purely for timing.
        gen_ids = None
        for _ in range(5):
            start_time = _clock()
            output_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            total_time += _clock() - start_time

            # Drop the prompt tokens; count the continuation for every timed run so that the token total
            # and the time total cover the same set of generations.
            gen_ids = output_ids[0, prompt_length:]
            num_tokens += len(gen_ids)

        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n")

    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")
|
|
|
|
# Script entry point: run the verification only when this file is executed directly, not when imported.
if __name__ == "__main__":
    verify_prismatic()
|
|