File size: 7,716 Bytes
167ea92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f46a35d
 
 
 
 
 
 
863dd32
f46a35d
0e15733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f46a35d
 
 
 
 
 
 
 
 
 
 
 
0e15733
f46a35d
 
 
 
 
 
0e15733
f46a35d
 
 
 
 
 
 
 
167ea92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863dd32
167ea92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from PIL import Image
import streamlit as st
from SkinGPT import SkinGPTClassifier
import numpy as np
from torchvision import transforms
import os

class SkinGPTTester:
    """Diagnostic harness for a SkinGPT skin-condition classifier.

    Bundles four manual QA utilities: Q-Former attention visualization,
    embedding similarity between two images, generation debugging, and a
    rough attention-based validation of a generated diagnosis.
    """

    def __init__(self, model_path="finetuned_dermnet_version1.pth"):
        # NOTE(review): `model_path` is currently unused — SkinGPTClassifier()
        # seems to load its own weights internally. Kept for backward
        # compatibility; confirm whether it should be forwarded.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classifier = SkinGPTClassifier()
        # Standard ImageNet preprocessing: 224x224 resize + ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def visualize_attention(self, image_path):
        """Render the Q-Former attention over an image and save it to
        'attention_visualization.png'.

        Args:
            image_path: Path to an image file readable by PIL.

        Side effects:
            Writes 'attention_visualization.png' to the working directory
            and prints debug information. Returns None early (with a
            warning) if no attention maps were captured.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Run the encoder purely for its side effect of populating
            # q_former.last_attention for this image.
            _ = self.classifier.model.encode_image(image_tensor)

            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return

            # First layer, first head/batch entry of the captured attention.
            # NOTE(review): indexing [0][0] looks like layer 0 rather than the
            # last layer despite the original comment — confirm intent.
            attention = self.classifier.model.q_former.last_attention[0][0]

            print(f"Attention shape: {attention.shape}")

            # Expected shape: [num_query_tokens + num_patches,
            #                  num_query_tokens + num_patches].
            # Rows past the query tokens are patch positions; the first
            # num_query_tokens columns attend to the query tokens. Averaging
            # over those columns yields one score per patch.
            # NOTE(review): this is patch->query attention, not query->patch
            # as the original comment claimed — verify which is intended.
            num_query_tokens = self.classifier.model.q_former.num_query_tokens
            attention_to_patches = attention[num_query_tokens:, :num_query_tokens].mean(dim=1)

            # Fold the flat per-patch scores into a square grid for display.
            num_patches = attention_to_patches.shape[0]
            h = w = int(math.sqrt(num_patches))

            if h * w != num_patches:
                print(f"Warning: Number of patches ({num_patches}) is not a perfect square")
                # Round up to the nearest square and zero-pad the tail so the
                # reshape below is always valid.
                h = w = int(math.ceil(math.sqrt(num_patches)))
                padded_attention = torch.zeros(h * w, device=attention_to_patches.device)
                padded_attention[:num_patches] = attention_to_patches
                attention_to_patches = padded_attention

            attention_map = attention_to_patches.reshape(h, w)

            # Three-panel figure: original, heatmap, overlay.
            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            plt.imshow(image)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(attention_map.cpu().numpy(), cmap='hot')
            plt.title('Attention Map')
            plt.axis('off')

            plt.subplot(1, 3, 3)
            plt.imshow(image)
            plt.imshow(attention_map.cpu().numpy(), alpha=0.5, cmap='hot')
            plt.title('Attention Overlay')
            plt.axis('off')

            plt.tight_layout()
            plt.savefig('attention_visualization.png')
            plt.close()

            print("Attention visualization saved as 'attention_visualization.png'")

    def check_feature_similarity(self, image_path1, image_path2):
        """Compare the embeddings of two images.

        Args:
            image_path1: Path to the first image.
            image_path2: Path to the second image.

        Returns:
            float: cosine similarity between the token-averaged embeddings.
        """
        image1 = Image.open(image_path1).convert('RGB')
        image2 = Image.open(image_path2).convert('RGB')

        with torch.no_grad():
            emb1 = self.classifier.model.encode_image(
                self.transform(image1).unsqueeze(0).to(self.device)
            )
            emb2 = self.classifier.model.encode_image(
                self.transform(image2).unsqueeze(0).to(self.device)
            )

            # Average over the token dimension before comparing, so images
            # with the same number of tokens reduce to one vector each.
            similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))

            print(f"\nFeature Similarity Analysis:")
            print(f"Image 1: {image_path1}")
            print(f"Image 2: {image_path2}")
            print(f"Cosine Similarity: {similarity.item():.4f}")
            print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
            print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
            print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")

            return similarity.item()

    def validate_response(self, image_path, diagnosis):
        """Report high-attention regions for an image alongside a diagnosis.

        Args:
            image_path: Path to the diagnosed image.
            diagnosis: The generated diagnosis text (printed for context only).

        Returns:
            Tensor of indices where attention exceeds mean + one std,
            or None if no attention maps are available.
        """
        image = Image.open(image_path).convert('RGB')

        with torch.no_grad():
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # BUG FIX: the original never ran the encoder here, so
            # last_attention belonged to a previously processed image (or was
            # None if this method ran first). Encode THIS image so the
            # attention we read actually corresponds to it.
            _ = self.classifier.model.encode_image(image_tensor)
            if self.classifier.model.q_former.last_attention is None:
                print("Warning: No attention maps available. Make sure output_attentions=True in BERT config.")
                return None
            attention = self.classifier.model.q_former.last_attention[0][0]

        # Fold into a 2D grid and threshold at mean + 1 std.
        # NOTE(review): assumes attention.shape[1] is a perfect square and the
        # total element count divides evenly — confirm for this model config.
        attention = attention.reshape(int(math.sqrt(attention.shape[1])), -1)
        high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()

        print(f"\nResponse Validation:")
        print(f"Image: {image_path}")
        print(f"Diagnosis: {diagnosis}")
        print(f"Number of high-attention regions: {len(high_attention_regions)}")

        return high_attention_regions

    def debug_generation(self, image_path, prompt=None):
        """Print embedding statistics, then run and print a full prediction.

        Args:
            image_path: Path to the input image.
            prompt: Optional user prompt forwarded to the classifier.

        Returns:
            dict: the classifier's prediction result (contains "diagnosis").
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_embeds = self.classifier.model.encode_image(image_tensor)

            print("\nGeneration Debug Information:")
            print(f"Image embedding shape: {image_embeds.shape}")
            print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
            print(f"Image embedding std: {image_embeds.std().item():.4f}")

            # Full generation path (prompt is optional).
            result = self.classifier.predict(image, user_input=prompt)

            print(f"\nGenerated Diagnosis:")
            print(result["diagnosis"])

            return result

def main():
    """Drive every SkinGPTTester diagnostic in sequence on two sample images."""
    tester = SkinGPTTester()

    # Sample inputs exercised by the steps below.
    primary, secondary = "1.jpg", "2.jpg"

    print("Running comprehensive tests...")

    # Step 1: attention heatmaps saved to disk.
    print("\n1. Visualizing attention maps...")
    tester.visualize_attention(primary)

    # Step 2: embedding-space comparison of the two samples.
    print("\n2. Checking feature similarity...")
    tester.check_feature_similarity(primary, secondary)

    # Step 3: generation pipeline with an explicit prompt; keep the result
    # because step 4 consumes the diagnosis text.
    print("\n3. Debugging generation process...")
    outcome = tester.debug_generation(primary, "Describe the skin condition in detail.")

    # Step 4: attention-based sanity check of the generated diagnosis.
    print("\n4. Validating response...")
    tester.validate_response(primary, outcome["diagnosis"])

    print("\nAll tests completed!")

if __name__ == "__main__":
    main()