|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import matplotlib.pyplot as plt |
|
|
import math |
|
|
from PIL import Image |
|
|
import streamlit as st |
|
|
from SkinGPT import SkinGPTClassifier |
|
|
import numpy as np |
|
|
from torchvision import transforms |
|
|
import os |
|
|
|
|
|
class SkinGPTTester:
    """Debugging / inspection harness for a SkinGPT classifier.

    Provides Q-Former attention visualization, image-embedding similarity
    checks, diagnosis validation against attention statistics, and
    generation debugging. All model work is delegated to
    ``SkinGPTClassifier``; this class only preprocesses, probes, and plots.
    """

    def __init__(self, model_path="finetuned_dermnet_version1.pth"):
        # NOTE(review): `model_path` is currently unused — SkinGPTClassifier()
        # presumably loads its own weights internally. Confirm and either wire
        # the path through or remove the parameter. Kept here so existing
        # callers are unaffected.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classifier = SkinGPTClassifier()
        # Standard ImageNet preprocessing; assumes the vision encoder was
        # trained with these statistics — TODO confirm against SkinGPT config.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def visualize_attention(self, image_path):
        """Visualize attention maps from Q-Former.

        Runs the image through the vision encoder (which populates
        ``q_former.last_attention`` as a side effect), then saves a
        three-panel figure (original / attention map / overlay) to
        'attention_visualization.png'.

        Args:
            image_path: Path to an image file readable by PIL.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Forward pass is only needed for its side effect of recording
            # attention on the Q-Former; the embeddings are discarded.
            _ = self.classifier.model.encode_image(image_tensor)

        if self.classifier.model.q_former.last_attention is None:
            print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
            return

        # First layer, first head. Assumes rows/cols are laid out as
        # [query tokens, image patches] — TODO confirm against the Q-Former
        # implementation.
        attention = self.classifier.model.q_former.last_attention[0][0]

        print(f"Attention shape: {attention.shape}")

        num_query_tokens = self.classifier.model.q_former.num_query_tokens
        # Patch rows attending to query-token columns, averaged over queries.
        attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)

        num_patches = attention_to_patches.shape[0]
        h = w = int(math.sqrt(num_patches))

        if h * w != num_patches:
            # ViT patch grids are normally square; if not, zero-pad so the
            # vector can still be reshaped and displayed.
            print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
            h = w = int(math.ceil(math.sqrt(num_patches)))
            padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
            padded_attention[:num_patches] = attention_to_patches
            attention_to_patches = padded_attention

        attention_map = attention_to_patches.reshape(h, w)

        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(attention_map.cpu().numpy(), cmap='hot')
        plt.title('Attention Map')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(image)
        # FIX: stretch the coarse h x w attention grid over the full image
        # extent. Previously the overlay was drawn in grid coordinates
        # (0..w, 0..h) while the image spans pixel coordinates, so the
        # heatmap occupied only a tiny corner and did not align.
        plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot',
                   extent=(0, image.width, image.height, 0))
        plt.title('Attention Overlay')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig('attention_visualization.png')
        plt.close()

        print(f"Attention visualization saved as 'attention_visualization.png'")

    def check_feature_similarity(self, image_path1, image_path2):
        """Compare embeddings of two images.

        Encodes both images, mean-pools each embedding over its token axis,
        and reports cosine similarity plus summary statistics.

        Args:
            image_path1: Path to the first image.
            image_path2: Path to the second image.

        Returns:
            float: Cosine similarity of the pooled embeddings.
        """
        image1 = Image.open(image_path1).convert('RGB')
        image2 = Image.open(image_path2).convert('RGB')

        with torch.no_grad():
            emb1 = self.classifier.model.encode_image(
                self.transform(image1).unsqueeze(0).to(self.device)
            )
            emb2 = self.classifier.model.encode_image(
                self.transform(image2).unsqueeze(0).to(self.device)
            )

        # Mean-pool over dim=1 (assumed to be the token/query axis — TODO
        # confirm the embedding layout), then compare directions.
        similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))

        print(f"\nFeature Similarity Analysis:")
        print(f"Image 1: {image_path1}")
        print(f"Image 2: {image_path2}")
        print(f"Cosine Similarity: {similarity.item():.4f}")
        print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
        print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
        print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")

        return similarity.item()

    def validate_response(self, image_path, diagnosis):
        """Validate if diagnosis contains relevant visual features.

        Runs the image through the encoder, then locates high-attention
        regions (above mean + one std) in the recorded Q-Former attention.

        Args:
            image_path: Path to the image the diagnosis refers to.
            diagnosis: Generated diagnosis text (printed for inspection).

        Returns:
            Tensor of indices of high-attention entries (possibly empty).
        """
        image = Image.open(image_path).convert('RGB')

        with torch.no_grad():
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # FIX: actually run THIS image through the model. Previously the
            # tensor was built but never used, so the method analyzed stale
            # attention left over from an earlier call — or crashed on None
            # if no forward pass had happened yet.
            _ = self.classifier.model.encode_image(image_tensor)

            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return torch.empty(0, 2, dtype=torch.long)

            attention = self.classifier.model.q_former.last_attention[0][0]

        # NOTE(review): this reshape assumes attention.shape[1] is a perfect
        # square and that sqrt(cols) rows capture the region of interest —
        # unlike visualize_attention it does not separate query tokens from
        # patches. TODO confirm intended semantics.
        attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
        high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()

        print(f"\nResponse Validation:")
        print(f"Image: {image_path}")
        print(f"Diagnosis: {diagnosis}")
        print(f"Number of high-attention regions: {len(high_attention_regions)}")

        return high_attention_regions

    def debug_generation(self, image_path, prompt=None):
        """Debug the generation process.

        Prints embedding statistics for the image, then runs the full
        classifier prediction and prints the generated diagnosis.

        Args:
            image_path: Path to the image to diagnose.
            prompt: Optional user prompt forwarded to the classifier.

        Returns:
            dict: Classifier result; expected to contain a "diagnosis" key.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_embeds = self.classifier.model.encode_image(image_tensor)

        print("\nGeneration Debug Information:")
        print(f"Image embedding shape: {image_embeds.shape}")
        print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
        print(f"Image embedding std: {image_embeds.std().item():.4f}")

        # Note: predict() receives the PIL image, not the tensor — the
        # classifier does its own preprocessing internally.
        result = self.classifier.predict(image, user_input=prompt)

        print(f"\nGenerated Diagnosis:")
        print(result["diagnosis"])

        return result
|
|
|
|
|
def main():
    """Run the full SkinGPT debugging suite on a pair of sample images."""
    runner = SkinGPTTester()

    primary_image = "1.jpg"
    comparison_image = "2.jpg"

    print("Running comprehensive tests...")

    print("\n1. Visualizing attention maps...")
    runner.visualize_attention(primary_image)

    print("\n2. Checking feature similarity...")
    runner.check_feature_similarity(primary_image, comparison_image)

    print("\n3. Debugging generation process...")
    outcome = runner.debug_generation(primary_image, "Describe the skin condition in detail.")

    print("\n4. Validating response...")
    runner.validate_response(primary_image, outcome["diagnosis"])

    print("\nAll tests completed!")


if __name__ == "__main__":
    main()