import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from PIL import Image
import streamlit as st
from SkinGPT import SkinGPTClassifier
import numpy as np
from torchvision import transforms
import os
class SkinGPTTester:
    """Diagnostic harness for a SkinGPT classifier.

    Provides attention-map visualization, embedding-similarity comparison,
    generation debugging, and a coarse attention-based response check.
    All heavy lifting is delegated to ``SkinGPTClassifier``; this class only
    preprocesses images and inspects the model's intermediate outputs.
    """

    def __init__(self, model_path="finetuned_dermnet_version1.pth"):
        """
        Args:
            model_path: Path to fine-tuned weights. NOTE(review): previously
                this argument was accepted but silently discarded; it is now
                stored on the instance so callers can inspect/use it. Whether
                SkinGPTClassifier should load it is not visible from here —
                confirm against the SkinGPT module.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path  # fix: was dropped on the floor
        self.classifier = SkinGPTClassifier()
        # 224x224 resize + standard ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_image_tensor(self, image_path):
        """Load an image file and return ``(pil_image, 1xCxHxW tensor on device)``."""
        image = Image.open(image_path).convert('RGB')
        return image, self.transform(image).unsqueeze(0).to(self.device)

    def _plot_attention(self, image, attention_map):
        """Save a 3-panel figure (original / heatmap / overlay) to
        ``attention_visualization.png``."""
        heat = attention_map.cpu().numpy()
        plt.figure(figsize=(15, 5))
        # Original image
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')
        # Attention heatmap alone
        plt.subplot(1, 3, 2)
        plt.imshow(heat, cmap='hot')
        plt.title('Attention Map')
        plt.axis('off')
        # Heatmap blended over the image
        plt.subplot(1, 3, 3)
        plt.imshow(image)
        plt.imshow(heat, alpha=0.5, cmap='hot')
        plt.title('Attention Overlay')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig('attention_visualization.png')
        plt.close()

    def visualize_attention(self, image_path):
        """Visualize Q-Former attention for *image_path* and save the figure.

        Runs the image through the encoder (which populates
        ``q_former.last_attention`` as a side effect), reduces the attention
        matrix to one weight per image patch, and writes
        ``attention_visualization.png``. Returns None.
        """
        image, image_tensor = self._load_image_tensor(image_path)
        with torch.no_grad():
            # Forward pass populates q_former.last_attention.
            _ = self.classifier.model.encode_image(image_tensor)
            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return
            # First layer, first head of the captured attention.
            # NOTE(review): assumes layout [layer][head, tokens, tokens] —
            # confirm against the Q-Former implementation.
            attention = self.classifier.model.q_former.last_attention[0][0]
            print(f"Attention shape: {attention.shape}")
            num_query_tokens = self.classifier.model.q_former.num_query_tokens
            # Rows after the query tokens are patch positions; averaging over
            # the query-token columns yields one scalar per patch. (This is
            # patch->query attention; the original comment claimed the
            # reverse direction — verify which one is intended upstream.)
            attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)
            num_patches = attention_to_patches.shape[0]
            h = w = int(math.sqrt(num_patches))
            if h * w != num_patches:
                print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
                # Round up to the next square and zero-pad the tail so the
                # vector can be reshaped into a 2D grid.
                h = w = int(math.ceil(math.sqrt(num_patches)))
                padded = torch.zeros(h * w, device=attention_to_patches.device)
                padded[:num_patches] = attention_to_patches
                attention_to_patches = padded
            attention_map = attention_to_patches.reshape(h, w)
            self._plot_attention(image, attention_map)
        print("Attention visualization saved as 'attention_visualization.png'")

    def check_feature_similarity(self, image_path1, image_path2):
        """Compare encoder embeddings of two images.

        Returns:
            float: cosine similarity between the token-averaged embeddings.
        """
        _, tensor1 = self._load_image_tensor(image_path1)
        _, tensor2 = self._load_image_tensor(image_path2)
        with torch.no_grad():
            emb1 = self.classifier.model.encode_image(tensor1)
            emb2 = self.classifier.model.encode_image(tensor2)
            # Mean over the token dimension before comparing.
            similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))
        print("\nFeature Similarity Analysis:")
        print(f"Image 1: {image_path1}")
        print(f"Image 2: {image_path2}")
        print(f"Cosine Similarity: {similarity.item():.4f}")
        print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
        print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
        print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")
        return similarity.item()

    def validate_response(self, image_path, diagnosis):
        """Report regions of above-average attention for *image_path*.

        Args:
            image_path: Path to the image the diagnosis refers to.
            diagnosis: Generated diagnosis text (printed for inspection only).

        Returns:
            Tensor of (row, col) indices of high-attention cells, or ``None``
            when no attention maps are available.
        """
        image, image_tensor = self._load_image_tensor(image_path)
        with torch.no_grad():
            # Fix: run the encoder so last_attention reflects THIS image.
            # Previously the tensor was built but never fed to the model, so
            # stale attention from an earlier call (or None) was read instead.
            _ = self.classifier.model.encode_image(image_tensor)
            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return None
            attention = self.classifier.model.q_former.last_attention[0][0]
            # NOTE(review): reshaping by sqrt(shape[1]) only works when the
            # token count happens to form a rectangle of that width — confirm
            # the intended geometry upstream.
            attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
            # Cells more than one std above the mean count as "high".
            high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()
        print("\nResponse Validation:")
        print(f"Image: {image_path}")
        print(f"Diagnosis: {diagnosis}")
        print(f"Number of high-attention regions: {len(high_attention_regions)}")
        return high_attention_regions

    def debug_generation(self, image_path, prompt=None):
        """Run a full predict pass, printing embedding statistics and the
        generated diagnosis.

        Args:
            image_path: Path to the input image.
            prompt: Optional user prompt forwarded to ``predict``.

        Returns:
            The classifier's result dict (contains at least ``"diagnosis"``).
        """
        image, image_tensor = self._load_image_tensor(image_path)
        with torch.no_grad():
            image_embeds = self.classifier.model.encode_image(image_tensor)
            print("\nGeneration Debug Information:")
            print(f"Image embedding shape: {image_embeds.shape}")
            print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
            print(f"Image embedding std: {image_embeds.std().item():.4f}")
            result = self.classifier.predict(image, user_input=prompt)
            print("\nGenerated Diagnosis:")
            print(result["diagnosis"])
        return result
def main():
    """Exercise every SkinGPTTester diagnostic against two sample images."""
    tester = SkinGPTTester()

    # Sample inputs: the second image is compared against the first.
    test_image, similar_image = "1.jpg", "2.jpg"

    print("Running comprehensive tests...")

    # Step 1: attention heatmap figure.
    print("\n1. Visualizing attention maps...")
    tester.visualize_attention(test_image)

    # Step 2: embedding cosine similarity.
    print("\n2. Checking feature similarity...")
    similarity = tester.check_feature_similarity(test_image, similar_image)

    # Step 3: end-to-end generation with a fixed prompt.
    print("\n3. Debugging generation process...")
    result = tester.debug_generation(test_image, "Describe the skin condition in detail.")

    # Step 4: sanity-check the diagnosis against attention regions.
    print("\n4. Validating response...")
    high_attention_regions = tester.validate_response(test_image, result["diagnosis"])

    print("\nAll tests completed!")


if __name__ == "__main__":
    main()