# vietmeagent/core/post_hoc_explainer.py
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
import logging
logger = logging.getLogger(__name__)
class PostHocExplainer:
"""
Post-hoc explanation module for generating visual explanations
Implements heatmaps to show which image regions influenced the answer
"""
def __init__(self, clip_model, clip_processor=None, device='cuda'):
self.clip_model = clip_model
self.clip_processor = clip_processor
self.device = device
# Validate inputs
if self.clip_model is None:
raise ValueError("CLIP model cannot be None")
if self.clip_processor is None:
logger.warning("CLIP processor is None, some methods may not work")
# Set model to evaluation mode
self.clip_model.eval()
logger.info("PostHocExplainer initialized with CLIP model")
def generate_heatmap(self, image, question_text=None, method='attention_rollout'):
"""Generate heatmap showing important image regions for VQA"""
logger.info(f"Generating heatmap using method: {method}")
try:
if method == 'attention_rollout':
return self.generate_attention_rollout_heatmap(image, question_text)
elif method == 'gradient_based':
return self.generate_gradient_heatmap(image, question_text)
elif method == 'occlusion':
return self.generate_occlusion_heatmap(image, question_text)
else:
logger.warning(f"Unknown method {method}, using attention_rollout")
return self.generate_attention_rollout_heatmap(image, question_text)
except Exception as e:
logger.error(f"Heatmap generation failed: {e}")
logger.info("Using fallback center-focused heatmap")
return self.create_center_fallback_heatmap()
    def generate_attention_rollout_heatmap(self, image, question_text=None):
        """Generate a heatmap from the vision encoder's attention weights.

        Despite the name, this uses only the last layer's attention map
        rather than the full multi-layer rollout product; the
        _attention_rollout helper below sketches the complete procedure.
        """
        logger.info("Generating attention rollout heatmap")
try:
# Check if processor is available
if self.clip_processor is None:
raise ValueError("CLIP processor is required for attention rollout")
# Prepare inputs
if question_text is None:
question_text = "What is in this image?"
# Process image and text with truncation
inputs = self.clip_processor(
text=[question_text],
images=image,
return_tensors="pt",
padding=True,
truncation=True,
max_length=77 # CLIP's maximum token length
).to(self.device)
logger.info("Running forward pass with attention outputs")
# Get attention weights
with torch.no_grad():
outputs = self.clip_model(**inputs, output_attentions=True)
# Try different ways to access vision attention
vision_attentions = None
# Method 1: Direct access
if hasattr(outputs, 'vision_model_output') and outputs.vision_model_output is not None:
if hasattr(outputs.vision_model_output, 'attentions'):
vision_attentions = outputs.vision_model_output.attentions
logger.info("Found vision attentions via vision_model_output")
# Method 2: Check if attentions are in main output
if vision_attentions is None and hasattr(outputs, 'attentions'):
vision_attentions = outputs.attentions
logger.info("Found attentions in main output")
# If still no attention, create fallback
if vision_attentions is None or len(vision_attentions) == 0:
logger.warning("No attention weights found, creating uniform attention")
attention_2d = torch.ones(7, 7) / 49
else:
# Extract attention from last layer
last_attention = vision_attentions[-1] # Last layer
                # Average across heads; take the first (only) batch element
                attention_map = last_attention.mean(dim=1)[0]  # [seq_len, seq_len]
                # Drop the CLS token to keep only patch-to-patch attention
                spatial_attention = attention_map[1:, 1:]
                # Reshape to the square patch grid
                patch_size = int(np.sqrt(spatial_attention.shape[0]))
                if spatial_attention.shape[0] == patch_size * patch_size:
                    # Average over queries (dim=0) so each patch is scored by how
                    # much attention it receives; averaging over keys (dim=1) is
                    # near-constant because softmax rows sum to one
                    attention_2d = spatial_attention.mean(dim=0).reshape(patch_size, patch_size)
                    logger.info(f"Reshaped attention to {patch_size}x{patch_size}")
else:
logger.warning(f"Cannot reshape attention {spatial_attention.shape}, using uniform")
attention_2d = torch.ones(7, 7) / 49
# Resize to 224x224
attention_2d = F.interpolate(
attention_2d.unsqueeze(0).unsqueeze(0),
size=(224, 224),
mode='bilinear',
align_corners=False
).squeeze().cpu().numpy()
# Normalize to [0, 1]
attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)
logger.info(f"Generated attention heatmap with shape {attention_2d.shape}")
return attention_2d
except Exception as e:
logger.warning(f"Attention rollout failed: {e}, using gradient method")
return self.generate_gradient_heatmap(image, question_text)
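    def _attention_rollout(self, attentions, residual_alpha=0.5):
        """Hedged sketch of full attention rollout (Abnar & Zuidema, 2020),
        not part of the original pipeline. Given the tuple of per-layer
        attention tensors that output_attentions=True returns, it multiplies
        the head-averaged attention matrices through every layer instead of
        reading only the last one; residual_alpha mixes in the identity to
        model the transformer's skip connections."""
        rollout = None
        for layer_attention in attentions:
            # Head-average, first batch element: [seq_len, seq_len]
            attn = layer_attention.mean(dim=1)[0]
            identity = torch.eye(attn.shape[0], device=attn.device, dtype=attn.dtype)
            attn = residual_alpha * attn + (1 - residual_alpha) * identity
            attn = attn / attn.sum(dim=-1, keepdim=True)  # keep rows stochastic
            rollout = attn if rollout is None else attn @ rollout
        # Row 0 is the CLS token's accumulated attention over all positions;
        # drop the CLS column itself to keep only the image patches
        return rollout[0, 1:]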
def generate_gradient_heatmap(self, image, question_text=None):
"""Generate heatmap using gradient-based method"""
logger.info("Generating gradient-based heatmap")
try:
if self.clip_processor is None:
raise ValueError("CLIP processor is required for gradient method")
if question_text is None:
question_text = "What is in this image?"
            # Keep the model in eval mode: input gradients only require
            # requires_grad on the pixel values, and train() would enable
            # dropout, perturbing the forward pass (and leaving the model in
            # train mode if an exception were raised below)
            self.clip_model.eval()
# Process inputs with truncation
inputs = self.clip_processor(
text=[question_text],
images=image,
return_tensors="pt",
padding=True,
truncation=True,
max_length=77 # CLIP's maximum token length
).to(self.device)
# Require gradients for pixel values
inputs['pixel_values'].requires_grad_(True)
logger.info("Running forward pass for gradients")
# Forward pass
outputs = self.clip_model(**inputs)
# Get image-text similarity score
logits_per_image = outputs.logits_per_image[0, 0]
logger.info("Computing gradients")
# Backward pass
logits_per_image.backward()
# Get gradients
gradients = inputs['pixel_values'].grad[0] # [C, H, W]
# Create heatmap from gradients
heatmap = torch.norm(gradients, dim=0).cpu().numpy() # [H, W]
# Normalize
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            # Clear parameter gradients accumulated by the backward pass
            self.clip_model.zero_grad(set_to_none=True)
logger.info(f"Generated gradient heatmap with shape {heatmap.shape}")
return heatmap
except Exception as e:
logger.warning(f"Gradient method failed: {e}, using occlusion method")
return self.generate_occlusion_heatmap(image, question_text)
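    def _smoothgrad_heatmap(self, inputs, n_samples=8, sigma=0.1):
        """Hedged sketch of SmoothGrad (Smilkov et al., 2017), not part of
        the original pipeline. Assumes `inputs` is a processor output already
        moved to self.device; averages gradient maps over noisy copies of the
        pixel values to reduce the speckle typical of raw input gradients."""
        base_pixels = inputs['pixel_values'].detach()
        accumulated = torch.zeros(base_pixels.shape[-2:], device=base_pixels.device)
        for _ in range(n_samples):
            # Perturb the input and recompute the image-text similarity gradient
            noisy = (base_pixels + sigma * torch.randn_like(base_pixels)).requires_grad_(True)
            outputs = self.clip_model(**{**inputs, 'pixel_values': noisy})
            outputs.logits_per_image[0, 0].backward()
            accumulated += torch.norm(noisy.grad[0], dim=0)
        self.clip_model.zero_grad(set_to_none=True)
        heatmap = (accumulated / n_samples).cpu().numpy()
        return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)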
def generate_occlusion_heatmap(self, image, question_text=None, patch_size=32):
"""Generate heatmap using occlusion method"""
logger.info("Generating occlusion-based heatmap")
try:
if self.clip_processor is None:
raise ValueError("CLIP processor is required for occlusion method")
if question_text is None:
question_text = "What is in this image?"
            # Convert to a 3-channel RGB array so resizing and the gray-patch
            # occlusion behave consistently for palette/RGBA/grayscale inputs
            if isinstance(image, Image.Image):
                image_np = np.array(image.convert('RGB'))
            else:
                image_np = image
# Resize to standard size
image_resized = cv2.resize(image_np, (224, 224))
image_pil = Image.fromarray(image_resized)
logger.info("Getting baseline score")
# Get baseline score
inputs_baseline = self.clip_processor(
text=[question_text],
images=image_pil,
return_tensors="pt",
padding=True,
truncation=True,
max_length=77 # CLIP's maximum token length
).to(self.device)
with torch.no_grad():
baseline_output = self.clip_model(**inputs_baseline)
baseline_score = baseline_output.logits_per_image[0, 0].cpu().item()
logger.info(f"Baseline score: {baseline_score}")
# Create heatmap
heatmap = np.zeros((224, 224))
# Occlude different regions
num_patches = 224 // patch_size
logger.info(f"Testing {num_patches}x{num_patches} patches")
for y in range(0, 224, patch_size):
for x in range(0, 224, patch_size):
try:
# Create occluded image
occluded_image = image_resized.copy()
y_end = min(y + patch_size, 224)
x_end = min(x + patch_size, 224)
occluded_image[y:y_end, x:x_end] = 128 # Gray patch
# Get score with occlusion
occluded_pil = Image.fromarray(occluded_image)
inputs_occluded = self.clip_processor(
text=[question_text],
images=occluded_pil,
return_tensors="pt",
padding=True,
truncation=True,
max_length=77 # CLIP's maximum token length
).to(self.device)
with torch.no_grad():
occluded_output = self.clip_model(**inputs_occluded)
occluded_score = occluded_output.logits_per_image[0, 0].cpu().item()
# Importance = baseline - occluded (higher drop = more important)
importance = baseline_score - occluded_score
heatmap[y:y_end, x:x_end] = importance
except Exception as e:
logger.warning(f"Occlusion patch ({x},{y}) failed: {e}")
continue
# Normalize heatmap
heatmap = np.maximum(heatmap, 0) # Keep only positive values
if heatmap.max() > 0:
heatmap = heatmap / heatmap.max()
logger.info(f"Generated occlusion heatmap with shape {heatmap.shape}")
return heatmap
except Exception as e:
logger.error(f"Occlusion method failed: {e}")
return self.create_center_fallback_heatmap()
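    def _batched_occlusion_scores(self, image_resized, question_text, patch_size=32, batch_size=16):
        """Hedged sketch, not part of the original pipeline: the loop in
        generate_occlusion_heatmap runs one forward pass per patch, so a
        7x7 grid costs 49 passes. Batching the occluded images through the
        processor and model cuts that roughly by `batch_size`. Assumes
        `image_resized` is an HxWx3 uint8 array already at 224x224."""
        occluded_images, coords = [], []
        for y in range(0, 224, patch_size):
            for x in range(0, 224, patch_size):
                occluded = image_resized.copy()
                occluded[y:y + patch_size, x:x + patch_size] = 128  # gray patch
                occluded_images.append(Image.fromarray(occluded))
                coords.append((y, x))
        scores = []
        for i in range(0, len(occluded_images), batch_size):
            inputs = self.clip_processor(
                text=[question_text],
                images=occluded_images[i:i + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77
            ).to(self.device)
            with torch.no_grad():
                # logits_per_image is [num_images, num_texts]; one text here
                scores.extend(self.clip_model(**inputs).logits_per_image[:, 0].cpu().tolist())
        return coords, scores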
def create_center_fallback_heatmap(self):
"""Create a center-focused fallback heatmap"""
logger.info("Creating fallback center-focused heatmap")
        # Radial gradient peaking at the image center and falling to zero
        # at radius 112, computed with vectorized numpy rather than a
        # per-pixel Python loop
        center_y, center_x = 112, 112
        ys, xs = np.mgrid[0:224, 0:224]
        distance = np.sqrt((ys - center_y) ** 2 + (xs - center_x) ** 2)
        return np.maximum(0, 1 - distance / 112)
def visualize_explanation(self, image, heatmap, title="VQA Explanation", save_path=None):
"""Visualize heatmap overlay on original image"""
try:
            # Prepare the original image as a 3-channel RGB array
            if isinstance(image, Image.Image):
                image_np = np.array(image.convert('RGB'))
            else:
                image_np = image
# Resize image to match heatmap
image_resized = cv2.resize(image_np, (heatmap.shape[1], heatmap.shape[0]))
image_resized = image_resized.astype(np.float32) / 255.0
# Create visualization
plt.figure(figsize=(15, 5))
# Original image
plt.subplot(1, 3, 1)
plt.imshow(image_resized)
plt.title("Original Image")
plt.axis('off')
# Heatmap
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='hot', interpolation='bilinear')
plt.title("Attention Heatmap")
plt.axis('off')
plt.colorbar()
# Overlay
plt.subplot(1, 3, 3)
plt.imshow(image_resized)
plt.imshow(heatmap, cmap='hot', alpha=0.6, interpolation='bilinear')
plt.title(title)
plt.axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"Visualization saved to {save_path}")
plt.close() # Close to prevent display in headless environment
return image_resized
except Exception as e:
logger.error(f"Visualization failed: {e}")
return None
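# Minimal usage sketch for PostHocExplainer (hedged: the checkpoint name is
# one common public CLIP choice, not something this module requires):
#
#     model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     explainer = PostHocExplainer(model, processor, device="cpu")
#     image = Image.open("photo.jpg")
#     heatmap = explainer.generate_heatmap(image, "What is in this image?", method="occlusion")
#     explainer.visualize_explanation(image, heatmap, save_path="explanation.png")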
class VietnameseExplanationGenerator:
"""Generate Vietnamese explanations for VQA results"""
def __init__(self, cultural_kb):
self.cultural_kb = cultural_kb
        # Vietnamese explanation templates (English glosses in comments)
        self.templates = {
            # "The image contains {object}, which is {description}. {cultural_significance}"
            'food': "Trong ảnh có {object}, đây là {description}. {cultural_significance}",
            # "The {object} garment in the image expresses {cultural_significance}"
            'clothing': "Trang phục {object} trong ảnh thể hiện {cultural_significance}",
            # "The {object} architecture carries the characteristics of {description}"
            'architecture': "Kiến trúc {object} mang đặc trưng {description}",
            # "The {object} activity carries the significance of {cultural_significance}"
            'activity': "Hoạt động {object} có ý nghĩa {cultural_significance}",
            # "In Vietnamese culture, the object {object} {description}"
            'general': "Đối tượng {object} trong văn hóa Việt Nam {description}"
        }
def generate_explanation(self, question, answer, cultural_objects, heatmap=None):
"""Generate Vietnamese cultural explanation"""
try:
explanations = []
            # Base explanation: "The answer '{answer}' was produced based on analysis of the image."
            base_explanation = f"Câu trả lời '{answer}' được đưa ra dựa trên phân tích hình ảnh."
explanations.append(base_explanation)
# Cultural explanations
for obj in cultural_objects:
if obj in self.cultural_kb['objects']:
obj_data = self.cultural_kb['objects'][obj]
category = obj_data.get('category', 'general')
template = self.templates.get(category, self.templates['general'])
cultural_exp = template.format(
object=obj,
description=obj_data.get('description', ''),
cultural_significance=obj_data.get('cultural_significance', '')
)
explanations.append(cultural_exp)
# Visual attention explanation
if heatmap is not None:
attention_exp = self.generate_attention_explanation(heatmap)
explanations.append(attention_exp)
return " ".join(explanations)
except Exception as e:
logger.warning(f"Explanation generation failed: {e}")
return f"Phân tích hình ảnh cho câu hỏi: {question}"
def generate_attention_explanation(self, heatmap):
"""Generate explanation about visual attention"""
try:
            # Attention statistics; note that the heatmap generators in this
            # module normalize their output to [0, 1], so max_attention is
            # close to 1 for most maps and the first branch dominates
            max_attention = np.max(heatmap)
            mean_attention = np.mean(heatmap)
            if max_attention > 0.8:
                # "The model focuses intensely on one specific region of the image."
                return "Mô hình tập trung cao độ vào một vùng cụ thể trong ảnh."
            elif mean_attention > 0.5:
                # "The model spreads its attention across many different regions."
                return "Mô hình phân tán sự chú ý trên nhiều vùng khác nhau."
            else:
                # "The model attends relatively evenly across the entire image."
                return "Mô hình có sự chú ý tương đối đều trên toàn bộ ảnh."
except Exception as e:
logger.warning(f"Attention explanation failed: {e}")
return "Phân tích sự chú ý của mô hình."