import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from PIL import Image
import streamlit as st
from SkinGPT import SkinGPTClassifier
import numpy as np
from torchvision import transforms
import os


class SkinGPTTester:
    """Diagnostic harness for a SkinGPT classifier.

    Provides four checks: Q-Former attention visualization, embedding
    similarity between two images, validation that a diagnosis corresponds
    to high-attention regions, and step-by-step generation debugging.
    """

    def __init__(self, model_path="finetuned_dermnet_version1.pth"):
        # NOTE(review): model_path is accepted but never used here —
        # presumably SkinGPTClassifier should load it; confirm its
        # constructor signature before relying on this parameter.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classifier = SkinGPTClassifier()
        # Standard ImageNet normalization at 224x224 (the usual ViT input).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def visualize_attention(self, image_path):
        """Visualize Q-Former attention over the input image.

        Saves a 3-panel figure (original, attention heatmap, overlay) to
        'attention_visualization.png'. Requires the BERT config of the
        Q-Former to have been built with output_attentions=True so that
        `last_attention` is populated by the forward pass.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Forward pass populates q_former.last_attention as a side effect.
            _ = self.classifier.model.encode_image(image_tensor)

            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return

            # First layer/batch/head entry of the stored attention stack.
            attention = self.classifier.model.q_former.last_attention[0][0]
            print(f"Attention shape: {attention.shape}")

            # NOTE(review): the slice below takes rows AFTER the query tokens
            # (patch positions) and columns WITHIN the query tokens, then
            # averages over the query axis. The original comment claimed this
            # is "attention from query tokens to image patches"; depending on
            # the [target, source] convention of last_attention it may be the
            # reverse direction — confirm against the Q-Former implementation.
            num_query_tokens = self.classifier.model.q_former.num_query_tokens
            attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)

            # Lay the per-patch scores out on a square grid.
            num_patches = attention_to_patches.shape[0]
            h = w = int(math.sqrt(num_patches))
            if h * w != num_patches:
                print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
                # Round up to the next square and zero-pad the tail so the
                # reshape below is always valid.
                h = w = int(math.ceil(math.sqrt(num_patches)))
                padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
                padded_attention[:num_patches] = attention_to_patches
                attention_to_patches = padded_attention

            attention_map = attention_to_patches.reshape(h, w)

            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            plt.imshow(image)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(attention_map.cpu().numpy(), cmap='hot')
            plt.title('Attention Map')
            plt.axis('off')

            # Overlay: heatmap drawn semi-transparently on top of the image.
            # NOTE(review): the heatmap is h x w while the image is full
            # resolution; matplotlib stretches each to its own extent, so the
            # alignment is approximate.
            plt.subplot(1, 3, 3)
            plt.imshow(image)
            plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot')
            plt.title('Attention Overlay')
            plt.axis('off')

            plt.tight_layout()
            plt.savefig('attention_visualization.png')
            plt.close()
            print(f"Attention visualization saved as 'attention_visualization.png'")

    def check_feature_similarity(self, image_path1, image_path2):
        """Compare the image embeddings of two images.

        Prints cosine similarity (of token-mean embeddings) plus basic
        embedding statistics, and returns the similarity as a float.
        """
        image1 = Image.open(image_path1).convert('RGB')
        image2 = Image.open(image_path2).convert('RGB')

        with torch.no_grad():
            emb1 = self.classifier.model.encode_image(
                self.transform(image1).unsqueeze(0).to(self.device)
            )
            emb2 = self.classifier.model.encode_image(
                self.transform(image2).unsqueeze(0).to(self.device)
            )

            # Pool over the token dimension before comparing, so the metric
            # is a single scalar per image pair.
            similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))

            print(f"\nFeature Similarity Analysis:")
            print(f"Image 1: {image_path1}")
            print(f"Image 2: {image_path2}")
            print(f"Cosine Similarity: {similarity.item():.4f}")
            print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
            print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
            print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")

            return similarity.item()

    def validate_response(self, image_path, diagnosis):
        """Report high-attention regions to sanity-check a diagnosis.

        Returns the indices of attention cells more than one standard
        deviation above the mean, or None if no attention is available.
        """
        image = Image.open(image_path).convert('RGB')

        with torch.no_grad():
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # BUG FIX: the original never ran the encoder here, so
            # last_attention was stale from a previous call — or None,
            # which crashed the subscript below. Run the forward pass first.
            _ = self.classifier.model.encode_image(image_tensor)

            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return None

            attention = self.classifier.model.q_former.last_attention[0][0]
            # NOTE(review): reshaping by sqrt(shape[1]) assumes the second
            # axis length is a perfect square — confirm this holds for the
            # Q-Former's attention layout.
            attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
            # Cells more than one std above the mean count as "high attention".
            high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()

            print(f"\nResponse Validation:")
            print(f"Image: {image_path}")
            print(f"Diagnosis: {diagnosis}")
            print(f"Number of high-attention regions: {len(high_attention_regions)}")

            return high_attention_regions

    def debug_generation(self, image_path, prompt=None):
        """Print embedding statistics, then run and return a full prediction."""
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_embeds = self.classifier.model.encode_image(image_tensor)

            print("\nGeneration Debug Information:")
            print(f"Image embedding shape: {image_embeds.shape}")
            print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
            print(f"Image embedding std: {image_embeds.std().item():.4f}")

            # Full generation path (note: predict takes the PIL image, not
            # the tensor — preprocessing happens inside the classifier).
            result = self.classifier.predict(image, user_input=prompt)
            print(f"\nGenerated Diagnosis:")
            print(result["diagnosis"])

            return result


def main():
    """Run the full diagnostic suite against two sample images."""
    tester = SkinGPTTester()

    test_image = "1.jpg"
    similar_image = "2.jpg"

    print("Running comprehensive tests...")

    print("\n1. Visualizing attention maps...")
    tester.visualize_attention(test_image)

    print("\n2. Checking feature similarity...")
    similarity = tester.check_feature_similarity(test_image, similar_image)

    print("\n3. Debugging generation process...")
    result = tester.debug_generation(test_image, "Describe the skin condition in detail.")

    print("\n4. Validating response...")
    high_attention_regions = tester.validate_response(test_image, result["diagnosis"])

    print("\nAll tests completed!")


if __name__ == "__main__":
    main()