import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
import logging

logger = logging.getLogger(__name__)
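
# This module only attaches a named logger; configuring handlers is left to the caller.
# A minimal sketch of what a caller might do (an assumption, not part of the pipeline):
#
#     import logging
#     logging.basicConfig(level=logging.INFO,
#                         format="%(asctime)s %(name)s %(levelname)s: %(message)s")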


class PostHocExplainer:
    """
    Post-hoc explanation module for generating visual explanations.
    Implements heatmaps to show which image regions influenced the answer.
    """

    def __init__(self, clip_model, clip_processor=None, device='cuda'):
        self.clip_model = clip_model
        self.clip_processor = clip_processor
        self.device = device

        # Validate inputs
        if self.clip_model is None:
            raise ValueError("CLIP model cannot be None")
        if self.clip_processor is None:
            logger.warning("CLIP processor is None, some methods may not work")

        # Set model to evaluation mode
        self.clip_model.eval()
        logger.info("PostHocExplainer initialized with CLIP model")

    def generate_heatmap(self, image, question_text=None, method='attention_rollout'):
        """Generate heatmap showing important image regions for VQA"""
        logger.info(f"Generating heatmap using method: {method}")
        try:
            if method == 'attention_rollout':
                return self.generate_attention_rollout_heatmap(image, question_text)
            elif method == 'gradient_based':
                return self.generate_gradient_heatmap(image, question_text)
            elif method == 'occlusion':
                return self.generate_occlusion_heatmap(image, question_text)
            else:
                logger.warning(f"Unknown method {method}, using attention_rollout")
                return self.generate_attention_rollout_heatmap(image, question_text)
        except Exception as e:
            logger.error(f"Heatmap generation failed: {e}")
            logger.info("Using fallback center-focused heatmap")
            return self.create_center_fallback_heatmap()

    def generate_attention_rollout_heatmap(self, image, question_text=None):
        """Generate heatmap from the vision encoder's last-layer attention
        (a simplified approximation of full attention rollout)."""
        logger.info("Generating attention rollout heatmap")
        try:
            # Check if processor is available
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for attention rollout")

            # Prepare inputs
            if question_text is None:
                question_text = "What is in this image?"

            # Process image and text with truncation
            inputs = self.clip_processor(
                text=[question_text],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77  # CLIP's maximum token length
            ).to(self.device)

            logger.info("Running forward pass with attention outputs")
            # Get attention weights
            with torch.no_grad():
                outputs = self.clip_model(**inputs, output_attentions=True)

            # Try different ways to access vision attention
            vision_attentions = None

            # Method 1: Direct access
            if hasattr(outputs, 'vision_model_output') and outputs.vision_model_output is not None:
                if hasattr(outputs.vision_model_output, 'attentions'):
                    vision_attentions = outputs.vision_model_output.attentions
                    logger.info("Found vision attentions via vision_model_output")

            # Method 2: Check if attentions are in main output
            if vision_attentions is None and hasattr(outputs, 'attentions'):
                vision_attentions = outputs.attentions
                logger.info("Found attentions in main output")

            # If still no attention, create fallback
            if vision_attentions is None or len(vision_attentions) == 0:
                logger.warning("No attention weights found, creating uniform attention")
                attention_2d = torch.ones(7, 7) / 49
            else:
                # Extract attention from last layer
                last_attention = vision_attentions[-1]
                # Average across heads and batch
                attention_map = last_attention.mean(dim=1)[0]  # [seq_len, seq_len]
                # Get spatial attention (excluding CLS token)
                spatial_attention = attention_map[1:, 1:]  # Remove CLS token
                # Reshape to spatial dimensions
                patch_size = int(np.sqrt(spatial_attention.shape[0]))
                if spatial_attention.shape[0] == patch_size * patch_size:
                    attention_2d = spatial_attention.mean(dim=1).reshape(patch_size, patch_size)
                    logger.info(f"Reshaped attention to {patch_size}x{patch_size}")
                else:
                    logger.warning(f"Cannot reshape attention {spatial_attention.shape}, using uniform")
                    attention_2d = torch.ones(7, 7) / 49

            # Resize to 224x224
            attention_2d = F.interpolate(
                attention_2d.unsqueeze(0).unsqueeze(0),
                size=(224, 224),
                mode='bilinear',
                align_corners=False
            ).squeeze().cpu().numpy()

            # Normalize to [0, 1]
            attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)

            logger.info(f"Generated attention heatmap with shape {attention_2d.shape}")
            return attention_2d

        except Exception as e:
            logger.warning(f"Attention rollout failed: {e}, using gradient method")
            return self.generate_gradient_heatmap(image, question_text)
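
    # The method above uses only the last attention layer. For reference, below is a
    # minimal sketch of full attention rollout (head-averaged attention mixed with the
    # identity, multiplied across all layers). It is illustrative only, not present in
    # the original pipeline and not called by generate_attention_rollout_heatmap.
    @staticmethod
    def _full_attention_rollout(attentions):
        """Multiply attention across layers; row 0 of the result gives CLS-to-token relevance."""
        seq_len = attentions[0].size(-1)
        rollout = torch.eye(seq_len, device=attentions[0].device)
        for layer_attention in attentions:
            # Average heads, take the first batch item: [seq_len, seq_len]
            attn = layer_attention.mean(dim=1)[0]
            # Account for residual connections by mixing with the identity, then renormalize rows
            attn = 0.5 * attn + 0.5 * torch.eye(seq_len, device=attn.device)
            attn = attn / attn.sum(dim=-1, keepdim=True)
            rollout = attn @ rollout
        return rollout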

    def generate_gradient_heatmap(self, image, question_text=None):
        """Generate heatmap using a gradient-based (input-saliency) method"""
        logger.info("Generating gradient-based heatmap")
        try:
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for gradient method")

            if question_text is None:
                question_text = "What is in this image?"

            # The model stays in eval mode; computing input gradients only requires
            # running the forward pass outside torch.no_grad().

            # Process inputs with truncation
            inputs = self.clip_processor(
                text=[question_text],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77  # CLIP's maximum token length
            ).to(self.device)

            # Require gradients for pixel values
            inputs['pixel_values'].requires_grad_(True)

            logger.info("Running forward pass for gradients")
            # Forward pass
            outputs = self.clip_model(**inputs)

            # Get image-text similarity score
            logits_per_image = outputs.logits_per_image[0, 0]

            logger.info("Computing gradients")
            # Backward pass
            logits_per_image.backward()

            # Get gradients
            gradients = inputs['pixel_values'].grad[0]  # [C, H, W]

            # Create heatmap from gradients
            heatmap = torch.norm(gradients, dim=0).cpu().numpy()  # [H, W]

            # Normalize
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

            # Clear any parameter gradients accumulated by backward()
            self.clip_model.zero_grad(set_to_none=True)

            logger.info(f"Generated gradient heatmap with shape {heatmap.shape}")
            return heatmap

        except Exception as e:
            logger.warning(f"Gradient method failed: {e}, using occlusion method")
            return self.generate_occlusion_heatmap(image, question_text)

    def generate_occlusion_heatmap(self, image, question_text=None, patch_size=32):
        """Generate heatmap using occlusion method"""
        logger.info("Generating occlusion-based heatmap")
        try:
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for occlusion method")

            if question_text is None:
                question_text = "What is in this image?"

            # Convert to numpy for processing
            if isinstance(image, Image.Image):
                image_np = np.array(image)
            else:
                image_np = image

            # Resize to standard size
            image_resized = cv2.resize(image_np, (224, 224))
            image_pil = Image.fromarray(image_resized)

            logger.info("Getting baseline score")
            # Get baseline score
            inputs_baseline = self.clip_processor(
                text=[question_text],
                images=image_pil,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77  # CLIP's maximum token length
            ).to(self.device)

            with torch.no_grad():
                baseline_output = self.clip_model(**inputs_baseline)
                baseline_score = baseline_output.logits_per_image[0, 0].cpu().item()

            logger.info(f"Baseline score: {baseline_score}")

            # Create heatmap
            heatmap = np.zeros((224, 224))

            # Occlude different regions
            num_patches = 224 // patch_size
            logger.info(f"Testing {num_patches}x{num_patches} patches")

            for y in range(0, 224, patch_size):
                for x in range(0, 224, patch_size):
                    try:
                        # Create occluded image
                        occluded_image = image_resized.copy()
                        y_end = min(y + patch_size, 224)
                        x_end = min(x + patch_size, 224)
                        occluded_image[y:y_end, x:x_end] = 128  # Gray patch

                        # Get score with occlusion
                        occluded_pil = Image.fromarray(occluded_image)
                        inputs_occluded = self.clip_processor(
                            text=[question_text],
                            images=occluded_pil,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=77  # CLIP's maximum token length
                        ).to(self.device)

                        with torch.no_grad():
                            occluded_output = self.clip_model(**inputs_occluded)
                            occluded_score = occluded_output.logits_per_image[0, 0].cpu().item()

                        # Importance = baseline - occluded (higher drop = more important)
                        importance = baseline_score - occluded_score
                        heatmap[y:y_end, x:x_end] = importance
                    except Exception as e:
                        logger.warning(f"Occlusion patch ({x},{y}) failed: {e}")
                        continue

            # Normalize heatmap
            heatmap = np.maximum(heatmap, 0)  # Keep only positive values
            if heatmap.max() > 0:
                heatmap = heatmap / heatmap.max()

            logger.info(f"Generated occlusion heatmap with shape {heatmap.shape}")
            return heatmap

        except Exception as e:
            logger.error(f"Occlusion method failed: {e}")
            return self.create_center_fallback_heatmap()

    def create_center_fallback_heatmap(self):
        """Create a center-focused fallback heatmap"""
        logger.info("Creating fallback center-focused heatmap")

        # Radial falloff from the image center (vectorized equivalent of a per-pixel loop)
        center_y, center_x = 112, 112
        ys, xs = np.ogrid[:224, :224]
        distance = np.sqrt((ys - center_y) ** 2 + (xs - center_x) ** 2)
        heatmap = np.clip(1 - distance / 112, 0, None)
        return heatmap

    def visualize_explanation(self, image, heatmap, title="VQA Explanation", save_path=None):
        """Visualize heatmap overlay on original image"""
        try:
            # Prepare original image
            if isinstance(image, Image.Image):
                image_np = np.array(image)
            else:
                image_np = image

            # Resize image to match heatmap
            image_resized = cv2.resize(image_np, (heatmap.shape[1], heatmap.shape[0]))
            image_resized = image_resized.astype(np.float32) / 255.0

            # Create visualization
            plt.figure(figsize=(15, 5))

            # Original image
            plt.subplot(1, 3, 1)
            plt.imshow(image_resized)
            plt.title("Original Image")
            plt.axis('off')

            # Heatmap
            plt.subplot(1, 3, 2)
            plt.imshow(heatmap, cmap='hot', interpolation='bilinear')
            plt.title("Attention Heatmap")
            plt.axis('off')
            plt.colorbar()

            # Overlay
            plt.subplot(1, 3, 3)
            plt.imshow(image_resized)
            plt.imshow(heatmap, cmap='hot', alpha=0.6, interpolation='bilinear')
            plt.title(title)
            plt.axis('off')

            plt.tight_layout()
            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"Visualization saved to {save_path}")
            plt.close()  # Close to prevent display in headless environment

            return image_resized
        except Exception as e:
            logger.error(f"Visualization failed: {e}")
            return None
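

# The three explanation methods trade faithfulness for cost: the attention heatmap needs
# one forward pass, the gradient heatmap a forward and a backward pass, and the occlusion
# heatmap one forward pass per occluded patch. A quick side-by-side comparison (a sketch;
# it assumes `explainer` and `image` are set up as in the demo at the bottom of this file):
#
#     for method in ("attention_rollout", "gradient_based", "occlusion"):
#         hm = explainer.generate_heatmap(image, "What is in this image?", method=method)
#         explainer.visualize_explanation(image, hm, title=method, save_path=f"{method}.png")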


class VietnameseExplanationGenerator:
    """Generate Vietnamese explanations for VQA results"""

    def __init__(self, cultural_kb):
        self.cultural_kb = cultural_kb

        # Vietnamese explanation templates (English glosses in the comments)
        self.templates = {
            # "The image contains {object}, which is {description}. {cultural_significance}"
            'food': "Trong ảnh có {object}, đây là {description}. {cultural_significance}",
            # "The {object} attire in the image expresses {cultural_significance}"
            'clothing': "Trang phục {object} trong ảnh thể hiện {cultural_significance}",
            # "The {object} architecture is characteristic of {description}"
            'architecture': "Kiến trúc {object} mang đặc trưng {description}",
            # "The activity {object} has the significance {cultural_significance}"
            'activity': "Hoạt động {object} có ý nghĩa {cultural_significance}",
            # "In Vietnamese culture, the object {object} {description}"
            'general': "Đối tượng {object} trong văn hóa Việt Nam {description}"
        }

    def generate_explanation(self, question, answer, cultural_objects, heatmap=None):
        """Generate a Vietnamese cultural explanation"""
        try:
            explanations = []

            # Base explanation: "The answer '{answer}' is based on an analysis of the image."
            base_explanation = f"Câu trả lời '{answer}' được đưa ra dựa trên phân tích hình ảnh."
            explanations.append(base_explanation)

            # Cultural explanations
            for obj in cultural_objects:
                if obj in self.cultural_kb['objects']:
                    obj_data = self.cultural_kb['objects'][obj]
                    category = obj_data.get('category', 'general')
                    template = self.templates.get(category, self.templates['general'])
                    cultural_exp = template.format(
                        object=obj,
                        description=obj_data.get('description', ''),
                        cultural_significance=obj_data.get('cultural_significance', '')
                    )
                    explanations.append(cultural_exp)

            # Visual attention explanation
            if heatmap is not None:
                attention_exp = self.generate_attention_explanation(heatmap)
                explanations.append(attention_exp)

            return " ".join(explanations)
        except Exception as e:
            logger.warning(f"Explanation generation failed: {e}")
            # Fallback: "Image analysis for the question: {question}"
            return f"Phân tích hình ảnh cho câu hỏi: {question}"

    def generate_attention_explanation(self, heatmap):
        """Generate an explanation about the model's visual attention"""
        try:
            # Calculate attention statistics
            max_attention = np.max(heatmap)
            mean_attention = np.mean(heatmap)

            if max_attention > 0.8:
                # "The model focuses strongly on one specific region of the image."
                return "Mô hình tập trung cao độ vào một vùng cụ thể trong ảnh."
            elif mean_attention > 0.5:
                # "The model spreads its attention over several different regions."
                return "Mô hình phân tán sự chú ý trên nhiều vùng khác nhau."
            else:
                # "The model attends relatively evenly to the whole image."
                return "Mô hình có sự chú ý tương đối đều trên toàn bộ ảnh."
        except Exception as e:
            logger.warning(f"Attention explanation failed: {e}")
            # Fallback: "Analysis of the model's attention."
            return "Phân tích sự chú ý của mô hình."
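

# End-to-end demo (a minimal sketch): the checkpoint name, image path, question,
# detected objects, and the toy cultural_kb below are illustrative assumptions; only
# the kb layout ({'objects': {name: {'category', 'description', 'cultural_significance'}}})
# follows how generate_explanation accesses it.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    explainer = PostHocExplainer(model, processor, device=device)

    image = Image.open("example.jpg").convert("RGB")  # assumed sample image
    question = "What food is shown in the image?"
    heatmap = explainer.generate_heatmap(image, question, method="occlusion")
    explainer.visualize_explanation(image, heatmap, save_path="explanation.png")

    # Toy knowledge-base entry for the Vietnamese explanation generator
    cultural_kb = {
        'objects': {
            'phở': {
                'category': 'food',
                'description': 'món ăn truyền thống của Việt Nam',  # "a traditional Vietnamese dish"
                'cultural_significance': 'gắn liền với ẩm thực Việt Nam',  # "closely tied to Vietnamese cuisine"
            }
        }
    }
    generator = VietnameseExplanationGenerator(cultural_kb)
    explanation = generator.generate_explanation(question, "phở", ['phở'], heatmap=heatmap)
    print(explanation)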