# SkinGPT / test.py
# Author: KeerthiVM — commit "Testing" (863dd32)
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from PIL import Image
import streamlit as st
from SkinGPT import SkinGPTClassifier
import numpy as np
from torchvision import transforms
import os
class SkinGPTTester:
    """Diagnostic harness for a SkinGPT classifier.

    Provides Q-Former attention visualization, image-embedding similarity
    comparison, response validation against high-attention regions, and
    generation debugging for a loaded ``SkinGPTClassifier``.
    """

    def __init__(self, model_path="finetuned_dermnet_version1.pth"):
        """Build the classifier and the image preprocessing pipeline.

        Args:
            model_path: Accepted for backward compatibility but currently
                unused here — presumably SkinGPTClassifier loads its own
                weights; TODO confirm and wire through if needed.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classifier = SkinGPTClassifier()
        # Standard ImageNet preprocessing (224x224, ImageNet mean/std),
        # matching ViT-style vision encoders.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def visualize_attention(self, image_path):
        """Visualize attention maps from Q-Former.

        Runs the image encoder, reads the Q-Former's recorded attention,
        and saves a 3-panel figure (image / attention map / overlay) to
        ``attention_visualization.png``.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Run the encoder so the Q-Former records fresh attention maps.
            _ = self.classifier.model.encode_image(image_tensor)
            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return
            # First batch element, first head. Assumed shape:
            # [num_query_tokens + num_patches, num_query_tokens + num_patches]
            # (TODO confirm — depends on the Q-Former's attention layout).
            attention = self.classifier.model.q_former.last_attention[0][0]
            print(f"Attention shape: {attention.shape}")
            num_query_tokens = self.classifier.model.q_former.num_query_tokens
            # BUG FIX: we want attention FROM query tokens TO image patches.
            # Under the usual convention rows are the attending tokens and
            # columns the attended-to tokens, so take query rows / patch
            # columns and average over the query dimension. The original
            # sliced patch rows / query columns, i.e. the reverse direction.
            attention_to_patches = attention[:num_query_tokens, num_query_tokens:].mean(dim=0)
            # Lay the per-patch scores out on the (assumed square) patch grid.
            num_patches = attention_to_patches.shape[0]
            h = w = int(math.sqrt(num_patches))
            if h * w != num_patches:
                print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
                # Pad up to the next square so a 2-D map can still be drawn.
                h = w = int(math.ceil(math.sqrt(num_patches)))
                padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
                padded_attention[:num_patches] = attention_to_patches
                attention_to_patches = padded_attention
            attention_map = attention_to_patches.reshape(h, w)
            # Plot: original image, raw attention heatmap, and overlay.
            plt.figure(figsize=(15, 5))
            plt.subplot(1, 3, 1)
            plt.imshow(image)
            plt.title('Original Image')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(attention_map.cpu().numpy(), cmap='hot')
            plt.title('Attention Map')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(image)
            plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot')
            plt.title('Attention Overlay')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig('attention_visualization.png')
            plt.close()
            print("Attention visualization saved as 'attention_visualization.png'")

    def check_feature_similarity(self, image_path1, image_path2):
        """Compare embeddings of two images.

        Returns:
            float: cosine similarity between the token-mean embeddings.
        """
        image1 = Image.open(image_path1).convert('RGB')
        image2 = Image.open(image_path2).convert('RGB')
        with torch.no_grad():
            emb1 = self.classifier.model.encode_image(
                self.transform(image1).unsqueeze(0).to(self.device)
            )
            emb2 = self.classifier.model.encode_image(
                self.transform(image2).unsqueeze(0).to(self.device)
            )
            # Mean-pool over the token dimension before the cosine; assumes
            # embeddings are [batch, tokens, dim] — TODO confirm upstream.
            similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))
        print("\nFeature Similarity Analysis:")
        print(f"Image 1: {image_path1}")
        print(f"Image 2: {image_path2}")
        print(f"Cosine Similarity: {similarity.item():.4f}")
        print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
        print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
        print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")
        return similarity.item()

    def validate_response(self, image_path, diagnosis):
        """Validate if diagnosis corresponds to salient visual regions.

        Returns the indices of attention cells more than one standard
        deviation above the mean (empty tensor if no attention is available).
        """
        image = Image.open(image_path).convert('RGB')
        with torch.no_grad():
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # BUG FIX: run the encoder for THIS image. Previously the stored
            # attention came from whichever image was processed last — or was
            # None, crashing on the subscript below.
            _ = self.classifier.model.encode_image(image_tensor)
            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return torch.empty(0, 2, dtype=torch.long)
            attention = self.classifier.model.q_former.last_attention[0][0]
            # NOTE(review): assumes attention.shape[1] is a perfect square so
            # the reshape succeeds — confirm against the Q-Former's layout.
            attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
            high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()
        print("\nResponse Validation:")
        print(f"Image: {image_path}")
        print(f"Diagnosis: {diagnosis}")
        print(f"Number of high-attention regions: {len(high_attention_regions)}")
        return high_attention_regions

    def debug_generation(self, image_path, prompt=None):
        """Debug the generation process.

        Prints embedding statistics for the image, then runs the classifier's
        full predict pipeline and returns its result dict (contains at least
        a "diagnosis" key, per usage below).
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embeds = self.classifier.model.encode_image(image_tensor)
            print("\nGeneration Debug Information:")
            print(f"Image embedding shape: {image_embeds.shape}")
            print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
            print(f"Image embedding std: {image_embeds.std().item():.4f}")
            result = self.classifier.predict(image, user_input=prompt)
            print("\nGenerated Diagnosis:")
            print(result["diagnosis"])
        return result
def main():
    """Run the full diagnostic suite on the bundled sample images."""
    tester = SkinGPTTester()

    # Sample images expected alongside this script.
    test_image = "1.jpg"
    similar_image = "2.jpg"

    print("Running comprehensive tests...")

    print("\n1. Visualizing attention maps...")
    tester.visualize_attention(test_image)

    print("\n2. Checking feature similarity...")
    # Similarity is printed by the method itself; no need to keep the value.
    tester.check_feature_similarity(test_image, similar_image)

    print("\n3. Debugging generation process...")
    result = tester.debug_generation(test_image, "Describe the skin condition in detail.")

    print("\n4. Validating response...")
    tester.validate_response(test_image, result["diagnosis"])

    print("\nAll tests completed!")


if __name__ == "__main__":
    main()