import logging

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

logger = logging.getLogger(__name__)


class PostHocExplainer:
    """
    Post-hoc explanation module for generating visual explanations.

    Implements heatmaps that show which image regions influenced the answer.
    """

    def __init__(self, clip_model, clip_processor=None, device='cuda'):
        self.clip_model = clip_model
        self.clip_processor = clip_processor
        self.device = device

        # Validate inputs
        if self.clip_model is None:
            raise ValueError("CLIP model cannot be None")
        if self.clip_processor is None:
            logger.warning("CLIP processor is None, some methods may not work")

        # Set model to evaluation mode
        self.clip_model.eval()
        logger.info("PostHocExplainer initialized with CLIP model")

    def generate_heatmap(self, image, question_text=None, method='attention_rollout'):
        """Generate a heatmap showing which image regions matter for the VQA answer."""
        logger.info(f"Generating heatmap using method: {method}")
        try:
            if method == 'attention_rollout':
                return self.generate_attention_rollout_heatmap(image, question_text)
            elif method == 'gradient_based':
                return self.generate_gradient_heatmap(image, question_text)
            elif method == 'occlusion':
                return self.generate_occlusion_heatmap(image, question_text)
            else:
                logger.warning(f"Unknown method {method}, using attention_rollout")
                return self.generate_attention_rollout_heatmap(image, question_text)
        except Exception as e:
            logger.error(f"Heatmap generation failed: {e}")
            logger.info("Using fallback center-focused heatmap")
            return self.create_center_fallback_heatmap()
    def generate_attention_rollout_heatmap(self, image, question_text=None):
        """Generate a heatmap from the vision transformer's last-layer attention.

        A lightweight stand-in for full attention rollout: only the final
        layer's CLS-to-patch attention is used.
        """
        logger.info("Generating attention rollout heatmap")
        try:
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for attention rollout")

            if question_text is None:
                question_text = "What is in this image?"

            # Process image and text, truncating to CLIP's 77-token limit
            inputs = self.clip_processor(
                text=[question_text],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            ).to(self.device)

            logger.info("Running forward pass with attention outputs")
            with torch.no_grad():
                outputs = self.clip_model(**inputs, output_attentions=True)

            # Try different ways to access the vision attention weights
            vision_attentions = None

            # Method 1: via the vision model output
            if hasattr(outputs, 'vision_model_output') and outputs.vision_model_output is not None:
                if hasattr(outputs.vision_model_output, 'attentions'):
                    vision_attentions = outputs.vision_model_output.attentions
                    logger.info("Found vision attentions via vision_model_output")

            # Method 2: attentions attached to the main output
            if vision_attentions is None and hasattr(outputs, 'attentions'):
                vision_attentions = outputs.attentions
                logger.info("Found attentions in main output")

            # If still no attention, fall back to a uniform map
            if vision_attentions is None or len(vision_attentions) == 0:
                logger.warning("No attention weights found, creating uniform attention")
                attention_2d = torch.ones(7, 7) / 49
            else:
                # Last layer, averaged across heads: [seq_len, seq_len]
                last_attention = vision_attentions[-1]
                attention_map = last_attention.mean(dim=1)[0]

                # CLS-to-patch attention (row 0, excluding the CLS column):
                # the CLS token aggregates the global image representation,
                # so its attention is the standard per-patch saliency here
                spatial_attention = attention_map[0, 1:]

                # Reshape the flat patch scores to a square spatial grid
                patch_size = int(np.sqrt(spatial_attention.shape[0]))
                if spatial_attention.shape[0] == patch_size * patch_size:
                    attention_2d = spatial_attention.reshape(patch_size, patch_size)
                    logger.info(f"Reshaped attention to {patch_size}x{patch_size}")
                else:
                    logger.warning(f"Cannot reshape attention {spatial_attention.shape}, using uniform")
                    attention_2d = torch.ones(7, 7) / 49

            # Upsample the patch grid to the 224x224 input resolution
            attention_2d = F.interpolate(
                attention_2d.unsqueeze(0).unsqueeze(0),
                size=(224, 224),
                mode='bilinear',
                align_corners=False,
            ).squeeze().cpu().numpy()

            # Normalize to [0, 1]
            attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)

            logger.info(f"Generated attention heatmap with shape {attention_2d.shape}")
            return attention_2d

        except Exception as e:
            logger.warning(f"Attention rollout failed: {e}, using gradient method")
            return self.generate_gradient_heatmap(image, question_text)
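    # The method above uses only the last attention layer. Full attention
    # rollout (Abnar & Zuidema, 2020) propagates attention through every
    # layer before reading off the CLS row. The sketch below is illustrative
    # and is not wired into generate_attention_rollout_heatmap; it assumes
    # `vision_attentions` is the tuple of per-layer tensors shaped
    # [batch, heads, seq_len, seq_len] collected above.
    @staticmethod
    def _full_attention_rollout(vision_attentions):
        """Multiply residual-adjusted attention across layers; return CLS-to-patch scores."""
        rollout = None
        for layer_attention in vision_attentions:
            # Average over heads, take the first batch element: [seq_len, seq_len]
            attn = layer_attention.mean(dim=1)[0]
            # Add the identity to model the residual connection, then
            # re-normalize so each row still sums to one
            attn = 0.5 * attn + 0.5 * torch.eye(attn.shape[0], device=attn.device)
            attn = attn / attn.sum(dim=-1, keepdim=True)
            # Accumulate the product across layers
            rollout = attn if rollout is None else attn @ rollout
        # CLS-token row, excluding the CLS column: one score per image patch
        return rollout[0, 1:]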
    def generate_gradient_heatmap(self, image, question_text=None):
        """Generate a heatmap from gradients of the image-text score w.r.t. the input pixels."""
        logger.info("Generating gradient-based heatmap")
        try:
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for gradient method")

            if question_text is None:
                question_text = "What is in this image?"

            # Process inputs, truncating to CLIP's 77-token limit
            inputs = self.clip_processor(
                text=[question_text],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            ).to(self.device)

            # Require gradients on the pixel values; the model can stay in
            # eval mode, since input gradients do not need train mode (and
            # train mode would activate dropout, distorting the scores)
            inputs['pixel_values'].requires_grad_(True)

            logger.info("Running forward pass for gradients")
            outputs = self.clip_model(**inputs)

            # Image-text similarity score for the single (image, text) pair
            logits_per_image = outputs.logits_per_image[0, 0]

            logger.info("Computing gradients")
            logits_per_image.backward()

            # Per-pixel gradient magnitude: [C, H, W] -> [H, W]
            gradients = inputs['pixel_values'].grad[0]
            heatmap = torch.norm(gradients, dim=0).cpu().numpy()

            # Normalize to [0, 1]
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

            logger.info(f"Generated gradient heatmap with shape {heatmap.shape}")
            return heatmap

        except Exception as e:
            logger.warning(f"Gradient method failed: {e}, using occlusion method")
            return self.generate_occlusion_heatmap(image, question_text)

    def generate_occlusion_heatmap(self, image, question_text=None, patch_size=32):
        """Generate a heatmap by occluding patches and measuring the score drop."""
        logger.info("Generating occlusion-based heatmap")
        try:
            if self.clip_processor is None:
                raise ValueError("CLIP processor is required for occlusion method")

            if question_text is None:
                question_text = "What is in this image?"

            # Convert to a numpy array for patch manipulation
            if isinstance(image, Image.Image):
                image_np = np.array(image)
            else:
                image_np = image

            # Resize to the standard CLIP input size
            image_resized = cv2.resize(image_np, (224, 224))
            image_pil = Image.fromarray(image_resized)

            logger.info("Getting baseline score")
            inputs_baseline = self.clip_processor(
                text=[question_text],
                images=image_pil,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            ).to(self.device)

            with torch.no_grad():
                baseline_output = self.clip_model(**inputs_baseline)
                baseline_score = baseline_output.logits_per_image[0, 0].cpu().item()

            logger.info(f"Baseline score: {baseline_score}")

            heatmap = np.zeros((224, 224))
            num_patches = 224 // patch_size
            logger.info(f"Testing {num_patches}x{num_patches} patches")

            # Occlude each patch in turn and record the score drop
            for y in range(0, 224, patch_size):
                for x in range(0, 224, patch_size):
                    try:
                        occluded_image = image_resized.copy()
                        y_end = min(y + patch_size, 224)
                        x_end = min(x + patch_size, 224)
                        occluded_image[y:y_end, x:x_end] = 128  # Gray patch

                        occluded_pil = Image.fromarray(occluded_image)
                        inputs_occluded = self.clip_processor(
                            text=[question_text],
                            images=occluded_pil,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=77,
                        ).to(self.device)

                        with torch.no_grad():
                            occluded_output = self.clip_model(**inputs_occluded)
                            occluded_score = occluded_output.logits_per_image[0, 0].cpu().item()

                        # Importance = baseline - occluded (a larger drop
                        # means the region mattered more)
                        importance = baseline_score - occluded_score
                        heatmap[y:y_end, x:x_end] = importance

                    except Exception as e:
                        logger.warning(f"Occlusion patch ({x},{y}) failed: {e}")
                        continue

            # Keep only positive importance and scale to [0, 1]
            heatmap = np.maximum(heatmap, 0)
            if heatmap.max() > 0:
                heatmap = heatmap / heatmap.max()

            logger.info(f"Generated occlusion heatmap with shape {heatmap.shape}")
            return heatmap

        except Exception as e:
            logger.error(f"Occlusion method failed: {e}")
            return self.create_center_fallback_heatmap()
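    # The loop in generate_occlusion_heatmap runs one forward pass per patch
    # (49 passes at patch_size=32). If GPU memory allows, the occluded
    # variants can instead be scored in a single batched forward pass. This
    # helper is an illustrative sketch, not called above; its name is
    # hypothetical, and it assumes `image_resized` is a 224x224 uint8 RGB
    # array as prepared in generate_occlusion_heatmap.
    def _occlusion_scores_batched(self, image_resized, question_text, patch_size=32):
        """Score every occluded variant of the image in one batched forward pass."""
        occluded_images, coords = [], []
        for y in range(0, 224, patch_size):
            for x in range(0, 224, patch_size):
                occluded = image_resized.copy()
                occluded[y:y + patch_size, x:x + patch_size] = 128  # Gray patch
                occluded_images.append(Image.fromarray(occluded))
                coords.append((y, x))
        inputs = self.clip_processor(
            text=[question_text],
            images=occluded_images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # logits_per_image: [num_variants, 1] -> one similarity per variant
        scores = outputs.logits_per_image[:, 0].cpu().numpy()
        return coords, scores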
    def create_center_fallback_heatmap(self):
        """Create a center-focused fallback heatmap."""
        logger.info("Creating fallback center-focused heatmap")
        # Radial gradient: 1.0 at the center, falling to 0 toward the borders
        y, x = np.ogrid[:224, :224]
        distance = np.sqrt((y - 112) ** 2 + (x - 112) ** 2)
        return np.maximum(0, 1 - distance / 112)

    def visualize_explanation(self, image, heatmap, title="VQA Explanation", save_path=None):
        """Visualize the heatmap overlaid on the original image."""
        try:
            # Prepare the original image as float RGB in [0, 1]
            if isinstance(image, Image.Image):
                image_np = np.array(image)
            else:
                image_np = image

            # Resize the image to match the heatmap
            image_resized = cv2.resize(image_np, (heatmap.shape[1], heatmap.shape[0]))
            image_resized = image_resized.astype(np.float32) / 255.0

            plt.figure(figsize=(15, 5))

            # Original image
            plt.subplot(1, 3, 1)
            plt.imshow(image_resized)
            plt.title("Original Image")
            plt.axis('off')

            # Heatmap
            plt.subplot(1, 3, 2)
            plt.imshow(heatmap, cmap='hot', interpolation='bilinear')
            plt.title("Attention Heatmap")
            plt.axis('off')
            plt.colorbar()

            # Overlay
            plt.subplot(1, 3, 3)
            plt.imshow(image_resized)
            plt.imshow(heatmap, cmap='hot', alpha=0.6, interpolation='bilinear')
            plt.title(title)
            plt.axis('off')

            plt.tight_layout()
            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"Visualization saved to {save_path}")
            plt.close()  # Close to prevent display in headless environments

            return image_resized

        except Exception as e:
            logger.error(f"Visualization failed: {e}")
            return None


class VietnameseExplanationGenerator:
    """Generate Vietnamese explanations for VQA results."""

    def __init__(self, cultural_kb):
        self.cultural_kb = cultural_kb

        # Vietnamese explanation templates, keyed by object category.
        # English glosses:
        #   food:         "The image shows {object}, which is {description}. {cultural_significance}"
        #   clothing:     "The {object} garment in the image expresses {cultural_significance}"
        #   architecture: "The {object} architecture carries the character of {description}"
        #   activity:     "The {object} activity has the significance of {cultural_significance}"
        #   general:      "In Vietnamese culture, the object {object} {description}"
        self.templates = {
            'food': "Trong ảnh có {object}, đây là {description}. {cultural_significance}",
            'clothing': "Trang phục {object} trong ảnh thể hiện {cultural_significance}",
            'architecture': "Kiến trúc {object} mang đặc trưng {description}",
            'activity': "Hoạt động {object} có ý nghĩa {cultural_significance}",
            'general': "Đối tượng {object} trong văn hóa Việt Nam {description}",
        }

    def generate_explanation(self, question, answer, cultural_objects, heatmap=None):
        """Generate a Vietnamese cultural explanation for a VQA result."""
        try:
            explanations = []

            # Base explanation: "The answer '{answer}' is based on analysis of the image."
            base_explanation = f"Câu trả lời '{answer}' được đưa ra dựa trên phân tích hình ảnh."
            explanations.append(base_explanation)

            # Cultural explanations for each recognized object
            for obj in cultural_objects:
                if obj in self.cultural_kb['objects']:
                    obj_data = self.cultural_kb['objects'][obj]
                    category = obj_data.get('category', 'general')
                    template = self.templates.get(category, self.templates['general'])
                    cultural_exp = template.format(
                        object=obj,
                        description=obj_data.get('description', ''),
                        cultural_significance=obj_data.get('cultural_significance', ''),
                    )
                    explanations.append(cultural_exp)

            # Visual attention explanation
            if heatmap is not None:
                attention_exp = self.generate_attention_explanation(heatmap)
                explanations.append(attention_exp)

            return " ".join(explanations)

        except Exception as e:
            logger.warning(f"Explanation generation failed: {e}")
            # Fallback: "Image analysis for the question: {question}"
            return f"Phân tích hình ảnh cho câu hỏi: {question}"
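    # The knowledge base consumed by generate_explanation is assumed to be a
    # dict with an 'objects' mapping whose entries may carry 'category',
    # 'description', and 'cultural_significance' keys (the only keys accessed
    # above). The sample entry below is hypothetical and purely illustrative.
    @staticmethod
    def example_cultural_kb():
        """Return a minimal sample knowledge base in the expected shape."""
        return {
            'objects': {
                'áo dài': {
                    'category': 'clothing',
                    'description': 'trang phục truyền thống',  # "traditional garment"
                    'cultural_significance': 'nét duyên dáng của người Việt',  # "the grace of the Vietnamese"
                },
            },
        }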
    def generate_attention_explanation(self, heatmap):
        """Generate a Vietnamese explanation of the model's visual attention pattern."""
        try:
            # Attention statistics over the normalized heatmap
            max_attention = np.max(heatmap)
            mean_attention = np.mean(heatmap)

            # English glosses of the three returned sentences:
            #   "The model focuses strongly on one specific region of the image."
            #   "The model spreads its attention across several different regions."
            #   "The model's attention is relatively even across the whole image."
            if max_attention > 0.8:
                return "Mô hình tập trung cao độ vào một vùng cụ thể trong ảnh."
            elif mean_attention > 0.5:
                return "Mô hình phân tán sự chú ý trên nhiều vùng khác nhau."
            else:
                return "Mô hình có sự chú ý tương đối đều trên toàn bộ ảnh."
        except Exception as e:
            logger.warning(f"Attention explanation failed: {e}")
            # Fallback: "Analysis of the model's attention."
            return "Phân tích sự chú ý của mô hình."
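

# --- Usage sketch (illustrative, not part of the module API) ---
# A minimal end-to-end run, assuming the Hugging Face checkpoint
# "openai/clip-vit-base-patch32" and a synthetic test image; swap in a real
# photo and your own cultural knowledge base as needed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    explainer = PostHocExplainer(model, processor, device=device)

    # A random RGB image stands in for a real photo in this sketch
    demo_image = Image.fromarray(
        np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    )
    heatmap = explainer.generate_heatmap(demo_image, "What is in this image?")
    explainer.visualize_explanation(demo_image, heatmap, save_path="explanation.png")

    # Generate a Vietnamese explanation using the hypothetical sample KB above
    generator = VietnameseExplanationGenerator(VietnameseExplanationGenerator.example_cultural_kb())
    print(generator.generate_explanation("Đây là gì?", "áo dài", ['áo dài'], heatmap))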