# Deepfake Explainer Vision Model with Multiple Image Support
This model is a fine-tuned version of Llama 3.2 Vision that analyzes multiple images for signs of deepfakes.
## Important: Cross-Attention Dimension Fix
This model requires special handling for cross-attention masks. When loading the model for inference,
make sure to fix the cross-attention mask dimensions as follows:
```python
# Sample code for loading and using this model
import torch
from peft import PeftModel
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    MllamaForConditionalGeneration,
)
# Load base model and processor.
# The base checkpoint is a Llama 3.2 Vision (image + text) model, so it is
# loaded with the Mllama conditional-generation class, which accepts image
# inputs alongside the text. AutoModelForCausalLM targets text-only causal
# language models and is not the documented loader for this architecture.
base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
processor = AutoProcessor.from_pretrained(base_model_id)
model = MllamaForConditionalGeneration.from_pretrained(base_model_id, device_map="auto")

# Load this adapter on top of the base model
adapter_id = "saakshigupta/deepfake-explainer-2"
model = PeftModel.from_pretrained(model, adapter_id)
# Function to fix cross-attention masks
def fix_processor_outputs(inputs, visual_features=6404):
    """Replace a degenerate cross-attention mask with an all-ones mask.

    If ``inputs`` contains a ``'cross_attention_mask'`` entry whose shape has
    a zero-sized dimension, that entry is overwritten (in place) with a tensor
    of ones shaped ``(batch_size, seq_len, visual_features, num_tiles)``, where
    ``batch_size``, ``seq_len`` and ``num_tiles`` are dimensions 0, 1 and 3 of
    the original mask. In every other case ``inputs`` is left untouched.

    Args:
        inputs: Mapping of processor outputs. It is mutated in place when the
            mask is replaced.
        visual_features: Size of dimension 2 of the replacement mask. The
            default, 6404, is the value this model card marks as critical.

    Returns:
        The same ``inputs`` object that was passed in.

    Raises:
        ValueError: If the mask has a zero-sized dimension but is not 4-D
            (the shape unpacking below fails).
    """
    if 'cross_attention_mask' not in inputs:
        return inputs

    mask = inputs['cross_attention_mask']
    if 0 not in mask.shape:
        # Mask already has usable dimensions; nothing to fix.
        return inputs

    # Only dimension 2 is replaced; the other three are carried over.
    batch_size, seq_len, _, num_tiles = mask.shape
    # The replacement uses torch's default dtype and stays on the original
    # mask's device.
    inputs['cross_attention_mask'] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        device=mask.device,
    )
    return inputs
# Function to process multiple images
def process_multiple_images(original_image, cam_image, cam_overlay, comparison_image, query):
    """Run the model on four images plus a text query and return the decoded text.

    The four images are passed to the module-level ``processor`` together with
    ``query``, the resulting cross-attention mask is patched with
    ``fix_processor_outputs``, and the module-level ``model`` generates up to
    500 new tokens.

    Args:
        original_image: First image handed to the processor.
        cam_image: Second image handed to the processor.
        cam_overlay: Third image handed to the processor.
        comparison_image: Fourth image handed to the processor.
        query: Text prompt passed to the processor as ``text``.

    Returns:
        The first generated sequence decoded with special tokens skipped.
        NOTE(review): the whole sequence returned by ``generate`` is decoded,
        so the text may begin with the prompt itself -- confirm and strip it
        if only the answer is wanted.
    """
    # Process with all four images
    # Note: This is a simplified approach and may need adaptation based on model capabilities
    inputs = processor(
        images=[original_image, cam_image, cam_overlay, comparison_image],
        text=query,
        return_tensors="pt",
    )

    # Fix cross-attention mask
    inputs = fix_processor_outputs(inputs)

    # Keep only tensor entries and move them to the model's device;
    # non-tensor entries are dropped before generation.
    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

    # Generate output without tracking gradients
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=500)

    response = processor.decode(output_ids[0], skip_special_tokens=True)
    return response
# Example usage: load the four input images as RGB, then ask the model
# for an explanation and print it.
original_image = Image.open("path/to/original.jpg").convert("RGB")
cam_image = Image.open("path/to/cam.jpg").convert("RGB")
cam_overlay = Image.open("path/to/overlay.jpg").convert("RGB")
comparison_image = Image.open("path/to/comparison.jpg").convert("RGB")

query = "Analyze these images and explain if they show a deepfake."
response = process_multiple_images(original_image, cam_image, cam_overlay, comparison_image, query)
print(response)
```
This model was fine-tuned specifically for deepfake detection and explanation using multiple image inputs.