File size: 11,154 Bytes
c5a3ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""
Minimal example to load and use the Coherence Detection model.
Requires the model weights in safetensors format and the matching
architecture definition from model_architecture.py.
"""

import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import sys
import os

# ============================================================================
# Import the exact architecture
# ============================================================================
# The model is built by load_coherence_model from the sibling module
# model_architecture; if that module cannot be imported the script cannot
# load any weights, so it reports the problem and exits.
try:
    from model_architecture import load_coherence_model
    print("✓ Imported exact model architecture")
except ImportError as e:
    # NOTE: this runs at import time, so importing this file as a module
    # also terminates the interpreter when model_architecture is missing.
    print(f"Error: model_architecture.py not found or has issues: {e}")
    print("Please download it from the repository.")
    sys.exit(1)

# ============================================================================
# CoherenceClassifier for safetensors
# ============================================================================
class CoherenceClassifier:
    """Inference wrapper around the coherence detection model.

    Constructing an instance builds the image preprocessing pipeline and
    loads the model through ``load_coherence_model``. Model outputs are
    mapped onto ``CATEGORIES`` by index.
    """

    # Categories in alphabetical order (as per training)
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """
        Build the preprocessing pipeline and load the model.

        Args:
            model_path: Path to .safetensors file (str or os.PathLike)
            device: "auto", "cuda", or "cpu"
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None

        # ImageNet normalization (standard for ResNet)
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]

        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Return a torch.device; "auto" picks CUDA when available, else CPU."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load the model with ``load_coherence_model`` and store it on ``self.model``.

        Prints the file name and device, and warns (without failing) when the
        path does not end in ``.safetensors``.
        """
        # os.fspath lets callers pass pathlib.Path objects as well as strings;
        # calling .lower() on a Path directly would raise AttributeError.
        path_str = os.fspath(self.model_path)

        print("Loading coherence model...")
        print(f"  File: {os.path.basename(path_str)}")
        print(f"  Device: {self.device}")

        # Verify file type
        if not path_str.lower().endswith('.safetensors'):
            print("⚠️  Warning: Expected .safetensors file for secure loading.")

        # Load using our helper function
        self.model = load_coherence_model(path_str, str(self.device))

        # This class is inference-only, so make sure the model is in
        # evaluation mode.
        # NOTE(review): load_coherence_model may already do this; calling
        # eval() a second time is harmless.
        if hasattr(self.model, "eval"):
            self.model.eval()

    def _setup_transforms(self):
        """Create the preprocessing pipeline: resize, to-tensor, normalize."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet standard size
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std),
        ])

    def preprocess_image(self, image_path):
        """Load an image file and return a batched tensor on ``self.device``.

        Args:
            image_path: Path to the image file

        Returns:
            Tensor with a leading batch dimension, i.e. [1, 3, 224, 224]

        Raises:
            Exception: Any error raised while opening or transforming the
                image is printed and then re-raised unchanged.
        """
        try:
            # The context manager releases the file handle right away;
            # convert() returns an independent, fully loaded RGB copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')

            # Apply transformations
            tensor = self.transform(image)

            # Add batch dimension [1, 3, 224, 224]
            return tensor.unsqueeze(0).to(self.device)

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def predict(self, image_path, return_probs=False):
        """
        Make a prediction on an image.

        Args:
            image_path: Path to the image file
            return_probs: If True, return all probabilities; if False, return only the top category

        Returns:
            Dictionary mapping each category to its probability, or the
            name of the top category as a string

        Raises:
            ValueError: If the model is not loaded, or if the model returns
                a different number of class scores than ``CATEGORIES`` has
                entries (the labels could not be assigned reliably).
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        # Preprocess the image
        input_tensor = self.preprocess_image(image_path)

        # Run inference without tracking gradients
        with torch.no_grad():
            output = self.model(input_tensor)

            # Convert to probabilities using softmax
            probabilities = torch.nn.functional.softmax(output, dim=1)
            probs = probabilities[0].cpu().numpy()

        # Labels are assigned by position, so a size mismatch would either
        # crash with an opaque IndexError or silently drop classes.
        if len(probs) != len(self.CATEGORIES):
            raise ValueError(
                f"Model returned {len(probs)} class scores but "
                f"{len(self.CATEGORIES)} categories are defined."
            )

        if return_probs:
            return {
                category: float(prob)
                for category, prob in zip(self.CATEGORIES, probs)
            }

        # Return the category with highest probability
        return self.CATEGORIES[int(np.argmax(probs))]

    def predict_batch(self, image_paths):
        """Make predictions for multiple images.

        Images are processed one at a time (one ``predict`` call each); the
        result is a list of probability dictionaries in input order.
        """
        return [self.predict(img_path, return_probs=True) for img_path in image_paths]


# ============================================================================
# Example usage: end-to-end demo
# ============================================================================
def _confidence_label(confidence):
    """Format a confidence value followed by a HIGH/MED/LOW text tag."""
    if confidence > 0.8:
        level = "HIGH"
    elif confidence > 0.6:
        level = "MED"
    else:
        level = "LOW"
    return f"{confidence:.4f} ({level})"


def _analyze_single_images(classifier, image_files):
    """Print per-category probabilities for each image; return how many succeeded."""
    # Interpretation notes keyed by the demo file names.
    notes = {
        "example_coherent.jpg":
            "Note: This example should ideally show high 'coherent' probability.",
        "example_semi_incoherent.jpg":
            "Note: This example demonstrates borderline/partial coherence issues.",
        "example_incoherent.jpg":
            "Note: This example should show obvious generation failures.",
    }

    analyzed = 0
    for img_file in image_files:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)

        try:
            # Get full probability distribution
            probs = classifier.predict(img_file, return_probs=True)
        except Exception as e:
            # Demo boundary: report the failure and move on to the next image.
            print(f"Error analyzing {img_file}: {e}")
            continue

        # Display all probabilities
        print("Prediction probabilities:")
        for category in classifier.CATEGORIES:
            prob = probs[category]
            # Visual indicator for high confidence (>70%)
            indicator = " ★" if prob > 0.7 else ""
            print(f"  {category:20} {prob:.4f}{indicator}")

        # Derive the top category from the probabilities already computed
        # instead of running the model a second time on the same image.
        top_category = max(probs, key=probs.get)
        print(f"\nTop prediction: '{top_category}' ({probs[top_category]:.4f})")

        note = notes.get(img_file)
        if note is not None:
            print(note)

        analyzed += 1

    return analyzed


def _print_batch_comparison(classifier, image_files):
    """Run predict_batch and print a comparison table; return True on success."""
    print("\n" + "=" * 60)
    print("2. BATCH PREDICTION COMPARISON")
    print("-" * 40)

    try:
        print(f"Running batch prediction on {len(image_files)} images...")
        batch_results = classifier.predict_batch(image_files)
    except Exception as e:
        print(f"Error in batch prediction: {e}")
        return False

    # Create a comparison table
    print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
    print("-" * 65)

    for img_path, result in zip(image_files, batch_results):
        top_cat = max(result, key=result.get)

        # Shorten filename if too long
        display_name = os.path.basename(img_path)
        if len(display_name) > 28:
            display_name = display_name[:25] + "..."

        # Label the prediction with a HIGH/MED/LOW confidence tag
        confidence_str = _confidence_label(result[top_cat])

        print(f"{display_name:30} {top_cat:20} {confidence_str:12}")

    print("\nBatch processing complete!")
    return True


def main():
    """Run the demo: load the model, analyze the example images, print a summary.

    Returns early (after printing an explanation) when the model file or all
    example images are missing from the current working directory.
    """

    # Update this path to your actual model file
    MODEL_PATH = "coherence_model.safetensors"  # Change to your model filename

    # The three example images the demo looks for
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg",
    ]

    # Check if model file exists
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return

    # Check which demo images exist
    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]

    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return

    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)

    # Initialize classifier
    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")

    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)

    # ========================================================================
    # PART 1: Detailed single image analysis for each available example
    # ========================================================================
    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)

    analyzed_count = _analyze_single_images(classifier, available_images)

    # ========================================================================
    # PART 2: Batch prediction comparison (if multiple images available)
    # ========================================================================
    batch_ok = True
    if len(available_images) > 1:
        batch_ok = _print_batch_comparison(classifier, available_images)

    # ========================================================================
    # PART 3: Quick summary
    # ========================================================================
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)

    print(f"✓ Model loaded successfully on: {classifier.device}")
    # Report only the images that were actually analyzed without error.
    print(f"✓ Analyzed {analyzed_count} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")

    failed_count = len(available_images) - analyzed_count
    if failed_count:
        print(f"\nNote: {failed_count} example image(s) could not be analyzed.")

    missing_images = [img for img in DEMO_IMAGES if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")

    print("\n" + "=" * 60)
    if failed_count or not batch_ok:
        print("Demo completed with errors (see messages above).")
    else:
        print("Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print("   ```python")
    print("   from example_usage import CoherenceClassifier")
    # The loader expects a .safetensors file, so the hint uses that extension.
    print("   classifier = CoherenceClassifier('your_model.safetensors')")
    print("   result = classifier.predict('your_image.jpg')")
    print("   ```")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()