from transformers import AutoProcessor
from PIL import Image
import os
import torch
import pickle

## ACTUAL INPUT CONSTRUCTION

# Token length of the speaker base prompt (instruction + 10 image slots) under
# the processor configuration in get_processor(); everything after the pad and
# base tokens is caption (see create_speaker_caption_mask).
BASE_SPEAKER_LEN = 787

def joint_listener_input(processor, context_images, description, device):
    # Preliminaries
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, context_images)
    target_anno = description.lower()
    prompt = construct_listener_full_prompt(
        processor, target_anno, 0, "verbose_instruction"
    )

    # Listener processing: drop the final two tokens so the model is scored on
    # predicting the target image index (target_idx=0 above is a placeholder)
    outputs = processor(
        text=[prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    l_input_tokens = outputs['input_ids'][:, :-2]
    l_attn_mask = outputs['attention_mask'][:, :-2]
    l_attn_mask[l_input_tokens == 0] = 0  # zero out attention on pad tokens (pad id 0)
    images = outputs['pixel_values']
    l_image_attn_mask = outputs['pixel_attention_mask']

    # Speaker processing: one full prompt per candidate target image
    prompts = []
    for i in range(10):
        prompt = construct_speaker_full_prompt(processor, description, i, "information_after")
        prompts.append(prompt)
    outputs = processor(
        text=prompts,
        images=[raw_images] * 10,
        padding='longest',
        return_tensors="pt"
    ).to(device)
    s_input_tokens = outputs['input_ids'][:, :-1]
    s_attn_mask = outputs['attention_mask'][:, :-1]
    s_attn_mask[s_input_tokens == 0] = 0
    s_image_attn_mask = outputs['pixel_attention_mask']
    s_target_tokens = outputs['input_ids'][:, 1:]

    # Mark which next-token targets belong to the caption
    s_target_mask = []
    for i in range(10):
        curr_mask = create_speaker_caption_mask(outputs['input_ids'][i], s_attn_mask[i])
        s_target_mask.append(curr_mask)
    s_target_mask = torch.stack(s_target_mask, dim=0)

    return images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens.unsqueeze(0), \
        s_attn_mask.unsqueeze(0), s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0), \
        s_target_tokens.unsqueeze(0)
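
# Illustrative usage sketch (file names below are hypothetical; assumes the
# pngs exist under tangram_pngs/):
#
#   processor = get_processor()
#   context = [f"tangram_{i}.png" for i in range(10)]  # hypothetical names
#   (images, l_tokens, l_mask, l_img_mask, s_tokens, s_mask, s_img_mask,
#    s_target_mask, s_targets) = joint_listener_input(
#        processor, context, "a dog facing left", torch.device("cuda"))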

def joint_speaker_input(processor, image_paths, target_path, device):
    # Get the prompt
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, image_paths)
    target_idx = image_paths.index(target_path)
    base_prompt = construct_speaker_base_prompt(processor, target_idx, "information_after", process=True)

    # Create the basic input
    outputs = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    input_tokens = outputs['input_ids']
    attn_mask = outputs['attention_mask']
    attn_mask[input_tokens == 0] = 0  # zero out attention on pad tokens
    images = outputs['pixel_values']
    image_attn_mask = outputs['pixel_attention_mask']
    return input_tokens, attn_mask, images, image_attn_mask, torch.LongTensor([target_idx]).to(device)
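
# Usage sketch (hypothetical names): target_path must be an element of
# image_paths, since the target index is recovered via list.index.
#
#   tokens, mask, imgs, img_mask, target = joint_speaker_input(
#       processor, context, context[3], torch.device("cuda"))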

## UTILITIES

def get_processor():
    checkpoint = "HuggingFaceM4/idefics2-8b"
    # Disable Idefics2 image splitting and bound the resize so that a
    # 10-image context stays small
    processor = AutoProcessor.from_pretrained(
        checkpoint,
        do_image_splitting=False,
        size={"longest_edge": 448, "shortest_edge": 224}
    )
    return processor

def get_index_to_token():
    index_to_token_path = "index_to_token.pkl"
    with open(index_to_token_path, 'rb') as f:
        index_to_token = pickle.load(f)
    return index_to_token
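
# Judging by the name alone, index_to_token.pkl is assumed to map vocabulary
# indices to token strings; this module only loads it, so consumers decode
# model outputs with it.
#
#   index_to_token = get_index_to_token()  # requires index_to_token.pkl beside this file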

def process_images(img_dir, context_images):
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        raw_image = Image.open(image_path).convert('RGB')
        raw_images.append(raw_image)
    return raw_images

def create_speaker_caption_mask(all_token_ids, text_mask):
    # Overall token composition: pad + base prompt + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + BASE_SPEAKER_LEN)

    # Construct a mask where only the trailing caption tokens are 1
    target_mask = torch.zeros_like(text_mask)
    target_mask[-caption_tokens:] = 1
    return target_mask.bool()
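
# Worked example with synthetic numbers: for a padded speaker sequence of
# length 800 containing 3 pad tokens, caption_tokens = 800 - (3 + 787) = 10,
# so the last 10 positions of the returned mask are True. text_mask arrives
# already shortened by one (it was sliced with [:, :-1]), which lines the mask
# up with the next-token targets taken from input_ids[:, 1:].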

def construct_listener_full_prompt(processor, target_anno, target_idx,
                                   comprehension_prompt_type="verbose_instruction"):
    target_anno = target_anno.lower().strip()
    messages = []
    if comprehension_prompt_type == "verbose_instruction":
        # User side: Intro
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
                    {"type": "text", "text": "Your task is to guess which image the caption describes. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type": "text", "text": f" Image {i}: "})
            else:
                messages[0]["content"].append({"type": "text", "text": f", Image {i}: "})
            messages[0]["content"].append({"type": "image"})

        # User side: Caption
        messages[0]["content"].append({"type": "text", "text": f". Caption: {target_anno}"})
        messages[0]["content"].append({"type": "text", "text": " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

        # Model side: Guess
        messages.append(
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": f"The caption describes Image {target_idx}"}
                ]
            }
        )
    else:
        raise ValueError(f"Unknown comprehension_prompt_type: {comprehension_prompt_type}")

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()
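
# Example call (illustrative): joint_listener_input always passes target_idx=0
# as a placeholder and strips the final two tokens downstream, so the model is
# scored on predicting the index itself.
#
#   prompt = construct_listener_full_prompt(processor, "a dog facing left", 0)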

def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    messages = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Assistant response
    target_anno = target_anno.lower().strip()
    messages.append(
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": target_anno}
            ]
        }
    )
    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()

def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    messages = []
    if generation_prompt_type == "information_after":
        # User side: Intro
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "You will be presented with a sequence of 10 images and be assigned a target image. "},
                    {"type": "text", "text": "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type": "text", "text": f" Image {i}: "})
            else:
                messages[0]["content"].append({"type": "text", "text": f", Image {i}: "})
            messages[0]["content"].append({"type": "image"})

        # User side: Target assignment
        messages[0]["content"].append({"type": "text", "text": f". Your target image is Image {target_idx}. Produce your caption now."})
    else:
        raise ValueError(f"Unknown generation_prompt_type: {generation_prompt_type}")

    if process:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        return prompt
    else:
        return messages
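
# Note: with process=True this returns the templated prompt string ready for
# the processor (the joint_speaker_input path); with process=False it returns
# the raw messages list so construct_speaker_full_prompt can append the
# assistant caption turn before templating.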

def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    # First construct the prompts
    prompts, raw_images = get_listener_generation_prompts(speaker_context, captions, num_samples, img_dir, processor)

    # Process the prompts; as in joint_listener_input, drop the final two
    # tokens so the listener predicts the image index
    listener_inputs = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )
    input_tokens = listener_inputs['input_ids'][:, :-2].to(device)
    attn_mask = listener_inputs['attention_mask'][:, :-2].to(device)
    attn_mask[input_tokens == 0] = 0
    images = listener_inputs['pixel_values'].to(device)
    image_attn_mask = listener_inputs['pixel_attention_mask'].to(device)
    return input_tokens, attn_mask, images, image_attn_mask

def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    prompts = []
    all_raw_images = []
    for i, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        for j in range(num_samples):
            # Captions are laid out contiguously: num_samples per context
            curr_idx = i * num_samples + j
            caption = captions[curr_idx]
            prompt = construct_listener_full_prompt(processor, caption, 0, "verbose_instruction")
            prompts.append(prompt)
            all_raw_images.append(raw_images)
    return prompts, all_raw_images
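
# End-to-end reranking sketch (hypothetical values): score num_samples sampled
# captions against one 10-image context with the listener.
#
#   processor = get_processor()
#   contexts = [[f"tangram_{i}.png" for i in range(10)]]  # hypothetical names
#   captions = ["first sampled caption", "second sampled caption"]
#   tokens, mask, imgs, img_mask = process_idefics_listener_generation_input(
#       contexts, captions, processor, "tangram_pngs", num_samples=2, device="cpu")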