reddesert committed
Commit c5a3ef9 · 1 Parent(s): 1ea77e1

Convert to safetensors format and add model architecture

.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md CHANGED
@@ -18,10 +18,10 @@ A ResNet-34 fine-tuned on a personally-curated dataset to classify images into o
18
  - Incoherent
19
  - Semi-Incoherent
20
 
21
- **Key Feature**: This model includes a production-ready loading wrapper that handles FastAI's `AdaptiveConcatPool2d` and safe/unsafe weight loading.
22
 
23
  ## Installation and Usage
24
- **Important**: This model requires custom loading code due to FastAI architecture.
25
 
26
  1. **Install**
27
  ```bash
@@ -38,7 +38,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
38
  pip install -r requirements.txt
39
  ```
40
 
41
- 2. **Download and run the example**:
42
  ```bash
43
  python example_usage.py
44
  ```
@@ -47,16 +47,13 @@ python example_usage.py
47
  ```python
48
  from example_usage import CoherenceClassifier
49
  # Initialize with your model
50
- classifier = CoherenceClassifier("coherence_model.pth")
51
  # Predict on an image
52
  result = classifier.predict("your_image.jpg", return_probs=True)
53
  print(result) # {'coherent': 0.85, 'incoherent': 0.05, 'semi-incoherent': 0.10}
54
  ```
55
 
56
- Note on `weights_only=False`: This model contains FastAI's AdaptiveConcatPool2d layer. You must either:
57
-
58
- - Use `weights_only=False` (only if you trust the source), OR
59
- - Register `AdaptiveConcatPool2d` in `__main__` before loading (see example_usage.py).
60
 
61
  ## Model Architecture
62
  Backbone: ResNet-34 (via FastAI's default)
@@ -68,7 +65,7 @@ For version 1.0, a small dataset was used of ~20k images in coherent category, a
68
 
69
  ## Limitations
70
  - PyTorch environment
71
- - `weights_only=False` loading has security implications; use only with trusted models.
72
 
73
  Note especially that although an attempt was made to classify the less obvious but still noticeable generation failures listed below into the "semi-incoherent" category, these are much harder to detect. In the current iteration of this model there is no expectation that they will be observed in any particular coherence category, especially when the errors occupy only a small part of the image.
74
  - Extra or missing limbs, fingers, or facial features
@@ -83,7 +80,7 @@ Note especially that while an attempt was made at classifying the less obvious b
83
  ## Comment
84
  Given the low coherence rate of results produced by early image generation models, it was very surprising that a model was not found for this purpose, necessitating the creation of this one for high-volume review scenarios.
85
 
86
- Perhaps models such as this one are avoided or seen as improper due to the perceived danger they pose in introduction of bias to image analysis, however it is highly likely that image generators would rather have at least some bias towards coherence and a somewhat clear mind when reviewing their image output than no bias and a mind littered with the psychologically damaging results of obviously-failed generations which have little to do with the prompter's intent.
87
 
88
  ## Model Card Authors
89
  Tom Hall
 
18
  - Incoherent
19
  - Semi-Incoherent
20
 
21
+ **Key Feature**: This model is provided in safetensors format with a production-ready loading wrapper (`model_architecture.py`) that handles FastAI's `AdaptiveConcatPool2d` layer automatically.
22
 
23
  ## Installation and Usage
24
+ **Important**: This model is provided in safetensors format and requires the `model_architecture.py` module for proper loading.
25
 
26
  1. **Install**
27
  ```bash
 
38
  pip install -r requirements.txt
39
  ```
40
 
41
+ 2. **Download entire repo and run the example**:
42
  ```bash
43
  python example_usage.py
44
  ```
 
47
  ```python
48
  from example_usage import CoherenceClassifier
49
  # Initialize with your model
50
+ classifier = CoherenceClassifier("coherence_model.safetensors")
51
  # Predict on an image
52
  result = classifier.predict("your_image.jpg", return_probs=True)
53
  print(result) # {'coherent': 0.85, 'incoherent': 0.05, 'semi-incoherent': 0.10}
54
  ```
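For high-volume review, the returned probability dict can be filtered with a plain threshold. A quick sketch (the `filter_coherent` helper and the 0.8 cutoff are illustrative assumptions, not part of the repository):

```python
# Hypothetical helper: keep only the images whose 'coherent' probability
# clears a threshold. The probabilities below are illustrative, not real output.
def filter_coherent(results, threshold=0.8):
    return [path for path, probs in results.items()
            if probs["coherent"] >= threshold]

batch = {
    "a.jpg": {"coherent": 0.91, "incoherent": 0.04, "semi-incoherent": 0.05},
    "b.jpg": {"coherent": 0.40, "incoherent": 0.35, "semi-incoherent": 0.25},
}
print(filter_coherent(batch))  # ['a.jpg']
```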
55
 
56
+ **Note**: The model uses FastAI's `AdaptiveConcatPool2d` layer. Import and use `model_architecture.py` which handles this automatically. The `example_usage.py` script demonstrates the proper import pattern.
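To make the note above concrete, here is a minimal standalone sketch of the `AdaptiveConcatPool2d` layer (mirroring its definition in `model_architecture.py`): it concatenates adaptive max-pooling and adaptive average-pooling outputs, doubling the channel count, which is why the classifier head expects 1024 inputs from the 512-channel ResNet-34 backbone.

```python
import torch
import torch.nn as nn

class AdaptiveConcatPool2d(nn.Module):
    """FastAI-style pooling: concat of adaptive max-pool and avg-pool."""
    def __init__(self, sz=None):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d(sz or 1)
        self.mp = nn.AdaptiveMaxPool2d(sz or 1)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)

# A ResNet-34 backbone emits [N, 512, 7, 7] for a 224x224 input;
# this layer turns it into [N, 1024, 1, 1].
pool = AdaptiveConcatPool2d()
features = torch.randn(2, 512, 7, 7)
pooled = pool(features)
print(pooled.shape)  # torch.Size([2, 1024, 1, 1])
```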
 
 
 
57
 
58
  ## Model Architecture
59
  Backbone: ResNet-34 (via FastAI's default)
 
65
 
66
  ## Limitations
67
  - PyTorch environment
68
+ - Requires `model_architecture.py` module for proper loading (handles FastAI-specific layers automatically)
69
 
70
  Note especially that although an attempt was made to classify the less obvious but still noticeable generation failures listed below into the "semi-incoherent" category, these are much harder to detect. In the current iteration of this model there is no expectation that they will be observed in any particular coherence category, especially when the errors occupy only a small part of the image.
71
  - Extra or missing limbs, fingers, or facial features
 
80
  ## Comment
81
  Given the low coherence rate of results produced by early image generation models, it was very surprising that a model was not found for this purpose, necessitating the creation of this one for high-volume review scenarios.
82
 
83
+ Perhaps models such as this one are avoided, or seen as improper, due to the perceived danger of introducing bias into image analysis. However, it is highly likely that image generators would rather have at least some bias towards coherence, and a somewhat clear mind when reviewing their image output, than no bias and a mind littered with the psychologically-damaging results of obviously-failed generations which have little to do with the prompter's intent.
84
 
85
  ## Model Card Authors
86
  Tom Hall
coherence_model.pth → coherence_model.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9086f65a0853e6c75751c1ae552ecf2ad213e7af92d72619aeee17a31e4c9be6
3
- size 87437751
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2f937854cc4ae58ebbe2e6b451006a5ee8dc65fabe5c32d33ef937cae24724b
3
+ size 87356700
example_usage.py CHANGED
@@ -1,326 +1,287 @@
1
- """
2
- Minimal example to load and use the FastAI-based Coherence Detection model.
3
- """
4
-
5
- import torch
6
- import torch.nn as nn
7
- from torchvision import transforms
8
- from PIL import Image
9
- import numpy as np
10
- import sys
11
- import os
12
-
13
- # ============================================================================
14
- # 1. CRITICAL: Register AdaptiveConcatPool2d for FastAI compatibility
15
- # ============================================================================
16
- def _register_adaptive_concat_pool():
17
- """Register AdaptiveConcatPool2d in __main__ for unpickling compatibility"""
18
- # Define the class
19
- class AdaptiveConcatPool2d(nn.Module):
20
- def __init__(self, sz=None):
21
- super().__init__()
22
- self.ap = nn.AdaptiveAvgPool2d(sz or 1)
23
- self.mp = nn.AdaptiveMaxPool2d(sz or 1)
24
-
25
- def forward(self, x):
26
- return torch.cat([self.mp(x), self.ap(x)], 1)
27
-
28
- # Register in the current module
29
- current_module = sys.modules[__name__]
30
- if not hasattr(current_module, 'AdaptiveConcatPool2d'):
31
- setattr(current_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d)
32
-
33
- # Also register in __main__ if it exists (required for unpickling)
34
- if '__main__' in sys.modules:
35
- main_module = sys.modules['__main__']
36
- if not hasattr(main_module, 'AdaptiveConcatPool2d'):
37
- setattr(main_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d)
38
-
39
- return AdaptiveConcatPool2d
40
-
41
- # Execute registration immediately
42
- AdaptiveConcatPool2d = _register_adaptive_concat_pool()
43
-
44
-
45
- # ============================================================================
46
- # 2. Simplified Model Loader
47
- # ============================================================================
48
- class CoherenceClassifier:
49
- """Simple loader for the FastAI-based coherence detection model."""
50
-
51
- # Model categories (update if your model uses different names)
52
- CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]
53
-
54
- def __init__(self, model_path, device="auto"):
55
- """
56
- Args:
57
- model_path: Path to the .pth model file
58
- device: "auto", "cuda", or "cpu"
59
- """
60
- self.model_path = model_path
61
- self.device = self._get_device(device)
62
- self.model = None
63
- self.transform = None
64
-
65
- # ImageNet normalization (standard for ResNet)
66
- self.normalize_mean = [0.485, 0.456, 0.406]
67
- self.normalize_std = [0.229, 0.224, 0.225]
68
-
69
- self.load_model()
70
-
71
- def _get_device(self, device_str):
72
- """Determine the best available device."""
73
- if device_str == "auto":
74
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
- return torch.device(device_str)
76
-
77
- def load_model(self):
78
- """Load the FastAI model with proper handling of AdaptiveConcatPool2d."""
79
- print(f"Loading model from {self.model_path}...")
80
- print(f"Using device: {self.device}")
81
-
82
- # IMPORTANT: weights_only=False is REQUIRED for FastAI models
83
- # because they contain custom layers like AdaptiveConcatPool2d
84
- print("Note: Using weights_only=False for FastAI compatibility")
85
-
86
- try:
87
- # Load the entire model (not just state_dict)
88
- self.model = torch.load(
89
- self.model_path,
90
- map_location=self.device,
91
- weights_only=False # Must be False for FastAI models
92
- )
93
-
94
- # Ensure model is in evaluation mode
95
- self.model.eval()
96
- print("Model loaded successfully!")
97
-
98
- except Exception as e:
99
- print(f"Error loading model: {e}")
100
- print("\nTroubleshooting tips:")
101
- print("1. Ensure AdaptiveConcatPool2d is registered before loading")
102
- print("2. Make sure the model file is not corrupted")
103
- print("3. Verify PyTorch and torchvision are installed")
104
- raise
105
-
106
- # Setup image transformations
107
- self._setup_transforms()
108
-
109
- def _setup_transforms(self):
110
- """Create image preprocessing pipeline."""
111
- self.transform = transforms.Compose([
112
- transforms.Resize((224, 224)), # ResNet standard size
113
- transforms.ToTensor(),
114
- transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
115
- ])
116
-
117
- def preprocess_image(self, image_path):
118
- """Load and preprocess an image for the model."""
119
- try:
120
- # Open and convert to RGB
121
- image = Image.open(image_path).convert('RGB')
122
-
123
- # Apply transformations
124
- tensor = self.transform(image)
125
-
126
- # Add batch dimension [1, 3, 224, 224]
127
- tensor = tensor.unsqueeze(0).to(self.device)
128
-
129
- return tensor
130
-
131
- except Exception as e:
132
- print(f"Error processing image {image_path}: {e}")
133
- raise
134
-
135
- def predict(self, image_path, return_probs=False):
136
- """
137
- Make a prediction on an image.
138
-
139
- Args:
140
- image_path: Path to the image file
141
- return_probs: If True, return all probabilities; if False, return only the top category
142
-
143
- Returns:
144
- Dictionary with predictions or string with top category
145
- """
146
- if self.model is None:
147
- raise ValueError("Model not loaded. Call load_model() first.")
148
-
149
- # Preprocess the image
150
- input_tensor = self.preprocess_image(image_path)
151
-
152
- # Run inference
153
- with torch.no_grad():
154
- output = self.model(input_tensor)
155
-
156
- # Convert to probabilities using softmax
157
- probabilities = torch.nn.functional.softmax(output, dim=1)
158
- probs = probabilities[0].cpu().numpy()
159
-
160
- # Create results dictionary
161
- results = {self.CATEGORIES[i]: float(probs[i]) for i in range(len(self.CATEGORIES))}
162
-
163
- if return_probs:
164
- return results
165
- else:
166
- # Return the category with highest probability
167
- top_idx = np.argmax(probs)
168
- return self.CATEGORIES[top_idx]
169
-
170
- def predict_batch(self, image_paths):
171
- """Make predictions for multiple images."""
172
- return [self.predict(img_path, return_probs=True) for img_path in image_paths]
173
-
174
-
175
- # ============================================================================
176
- # 3. Example Usage
177
- # ============================================================================
178
- def main():
179
- """Example demonstrating how to use the classifier with all three test images."""
180
-
181
- # Update this path to your actual model file
182
- MODEL_PATH = "coherence_model.pth" # Change to your model filename
183
-
184
- # All three example images you've provided
185
- DEMO_IMAGES = [
186
- "example_coherent.jpg",
187
- "example_semi_incoherent.jpg",
188
- "example_incoherent.jpg"
189
- ]
190
-
191
- # Check if model file exists
192
- if not os.path.exists(MODEL_PATH):
193
- print(f"Error: Model file not found at {MODEL_PATH}")
194
- print("Please download the model from the Hugging Face repository:")
195
- print("https://huggingface.co/your-username/your-model-name")
196
- return
197
-
198
- # Check which demo images exist
199
- available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
200
-
201
- if not available_images:
202
- print("Error: No example images found.")
203
- print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
204
- return
205
-
206
- print("=" * 60)
207
- print("Coherence Detection Model - Complete Demo")
208
- print("=" * 60)
209
-
210
- # Initialize classifier
211
- print(f"Loading model from: {MODEL_PATH}")
212
- classifier = CoherenceClassifier(MODEL_PATH, device="auto")
213
-
214
- print(f"\nFound {len(available_images)} example image(s) for demonstration.")
215
- print("-" * 60)
216
-
217
- # ========================================================================
218
- # PART 1: Detailed single image analysis for each available example
219
- # ========================================================================
220
- print("\n1. SINGLE IMAGE ANALYSIS")
221
- print("-" * 40)
222
-
223
- for img_file in available_images:
224
- print(f"\nAnalyzing: {img_file}")
225
- print("-" * 30)
226
-
227
- try:
228
- # Get full probability distribution
229
- probs = classifier.predict(img_file, return_probs=True)
230
-
231
- # Display all probabilities
232
- print("Prediction probabilities:")
233
- for category in classifier.CATEGORIES:
234
- prob = probs[category]
235
- # Visual indicator for high confidence (>70%)
236
- indicator = " ★" if prob > 0.7 else ""
237
- print(f" {category:20} {prob:.4f}{indicator}")
238
-
239
- # Get and display top category
240
- top_category = classifier.predict(img_file, return_probs=False)
241
- top_prob = probs[top_category]
242
-
243
- print(f"\nTop prediction: '{top_category}' ({top_prob:.4f})")
244
-
245
- # Add interpretation note based on image name
246
- if "coherent" in img_file:
247
- print("Note: This example should ideally show high 'coherent' probability.")
248
- elif "semi_incoherent" in img_file:
249
- print("Note: This example demonstrates borderline/partial coherence issues.")
250
- elif "incoherent" in img_file:
251
- print("Note: This example should show obvious generation failures.")
252
-
253
- except Exception as e:
254
- print(f"Error analyzing {img_file}: {e}")
255
- continue
256
-
257
- # ========================================================================
258
- # PART 2: Batch prediction comparison (if multiple images available)
259
- # ========================================================================
260
- if len(available_images) > 1:
261
- print("\n" + "=" * 60)
262
- print("2. BATCH PREDICTION COMPARISON")
263
- print("-" * 40)
264
-
265
- try:
266
- print(f"Running batch prediction on {len(available_images)} images...")
267
- batch_results = classifier.predict_batch(available_images)
268
-
269
- # Create a comparison table
270
- print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
271
- print("-" * 65)
272
-
273
- for img_path, result in zip(available_images, batch_results):
274
- top_cat = max(result, key=result.get)
275
- confidence = result[top_cat]
276
-
277
- # Shorten filename if too long
278
- display_name = os.path.basename(img_path)
279
- if len(display_name) > 28:
280
- display_name = display_name[:25] + "..."
281
-
282
- # Color code high confidence predictions
283
- if confidence > 0.8:
284
- confidence_str = f"{confidence:.4f} (HIGH)"
285
- elif confidence > 0.6:
286
- confidence_str = f"{confidence:.4f} (MED)"
287
- else:
288
- confidence_str = f"{confidence:.4f} (LOW)"
289
-
290
- print(f"{display_name:30} {top_cat:20} {confidence_str:12}")
291
-
292
- print("\nBatch processing complete!")
293
-
294
- except Exception as e:
295
- print(f"Error in batch prediction: {e}")
296
-
297
- # ========================================================================
298
- # PART 3: Quick summary
299
- # ========================================================================
300
- print("\n" + "=" * 60)
301
- print("DEMO SUMMARY")
302
- print("-" * 40)
303
-
304
- print(f"✓ Model loaded successfully on: {classifier.device}")
305
- print(f"✓ Analyzed {len(available_images)} example image(s)")
306
- print(f"✓ Example categories: {classifier.CATEGORIES}")
307
-
308
- missing_images = [img for img in DEMO_IMAGES if img not in available_images]
309
- if missing_images:
310
- print(f"\nNote: Missing example images: {missing_images}")
311
- print("To complete the demo, add these images to the directory.")
312
-
313
- print("\n" + "=" * 60)
314
- print("Demo completed successfully!")
315
- print("\nNext steps:")
316
- print("1. Try your own images by modifying the DEMO_IMAGES list")
317
- print("2. Use the classifier in your own code:")
318
- print(" ```python")
319
- print(" from example_usage import CoherenceClassifier")
320
- print(" classifier = CoherenceClassifier('your_model.pth')")
321
- print(" result = classifier.predict('your_image.jpg')")
322
- print(" ```")
323
-
324
-
325
- if __name__ == "__main__":
326
- main()
 
1
+ """
2
+ Minimal example to load and use the Coherence Detection model.
3
+ Requires safetensors format with exact architecture.
4
+ """
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import numpy as np
10
+ import sys
11
+ import os
12
+
13
+ # ============================================================================
14
+ # Import the exact architecture
15
+ # ============================================================================
16
+ try:
17
+ from model_architecture import load_coherence_model
18
+ print("✓ Imported exact model architecture")
19
+ except ImportError as e:
20
+ print(f"Error: model_architecture.py not found or has issues: {e}")
21
+ print("Please download it from the repository.")
22
+ sys.exit(1)
23
+
24
+ # ============================================================================
25
+ # CoherenceClassifier for safetensors
26
+ # ============================================================================
27
+ class CoherenceClassifier:
28
+ """Loader for coherence detection model (exact architecture)."""
29
+
30
+ # Categories in alphabetical order (as per training)
31
+ CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]
32
+
33
+ def __init__(self, model_path, device="auto"):
34
+ """
35
+ Args:
36
+ model_path: Path to .safetensors file
37
+ device: "auto", "cuda", or "cpu"
38
+ """
39
+ self.model_path = model_path
40
+ self.device = self._get_device(device)
41
+ self.model = None
42
+ self.transform = None
43
+
44
+ # ImageNet normalization (standard for ResNet)
45
+ self.normalize_mean = [0.485, 0.456, 0.406]
46
+ self.normalize_std = [0.229, 0.224, 0.225]
47
+
48
+ self._setup_transforms()
49
+ self.load_model()
50
+
51
+ def _get_device(self, device_str):
52
+ """Determine the best available device."""
53
+ if device_str == "auto":
54
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ return torch.device(device_str)
56
+
57
+ def load_model(self):
58
+ """Load model using safetensors and exact architecture."""
59
+ print("Loading coherence model...")
60
+ print(f" File: {os.path.basename(self.model_path)}")
61
+ print(f" Device: {self.device}")
62
+
63
+ # Verify file type
64
+ if not self.model_path.lower().endswith('.safetensors'):
65
+ print("⚠️ Warning: Expected .safetensors file for secure loading.")
66
+
67
+ # Load using our helper function
68
+ self.model = load_coherence_model(self.model_path, str(self.device))
69
+
70
+ def _setup_transforms(self):
71
+ """Create image preprocessing pipeline."""
72
+ self.transform = transforms.Compose([
73
+ transforms.Resize((224, 224)), # ResNet standard size
74
+ transforms.ToTensor(),
75
+ transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
76
+ ])
77
+
78
+ def preprocess_image(self, image_path):
79
+ """Load and preprocess an image for the model."""
80
+ try:
81
+ # Open and convert to RGB
82
+ image = Image.open(image_path).convert('RGB')
83
+
84
+ # Apply transformations
85
+ tensor = self.transform(image)
86
+
87
+ # Add batch dimension [1, 3, 224, 224]
88
+ tensor = tensor.unsqueeze(0).to(self.device)
89
+
90
+ return tensor
91
+
92
+ except Exception as e:
93
+ print(f"Error processing image {image_path}: {e}")
94
+ raise
95
+
96
+ def predict(self, image_path, return_probs=False):
97
+ """
98
+ Make a prediction on an image.
99
+
100
+ Args:
101
+ image_path: Path to the image file
102
+ return_probs: If True, return all probabilities; if False, return only the top category
103
+
104
+ Returns:
105
+ Dictionary with predictions or string with top category
106
+ """
107
+ if self.model is None:
108
+ raise ValueError("Model not loaded. Call load_model() first.")
109
+
110
+ # Preprocess the image
111
+ input_tensor = self.preprocess_image(image_path)
112
+
113
+ # Run inference
114
+ with torch.no_grad():
115
+ output = self.model(input_tensor)
116
+
117
+ # Convert to probabilities using softmax
118
+ probabilities = torch.nn.functional.softmax(output, dim=1)
119
+ probs = probabilities[0].cpu().numpy()
120
+
121
+ # Create results dictionary
122
+ results = {self.CATEGORIES[i]: float(probs[i]) for i in range(len(self.CATEGORIES))}
123
+
124
+ if return_probs:
125
+ return results
126
+ else:
127
+ # Return the category with highest probability
128
+ top_idx = np.argmax(probs)
129
+ return self.CATEGORIES[top_idx]
130
+
131
+ def predict_batch(self, image_paths):
132
+ """Make predictions for multiple images."""
133
+ return [self.predict(img_path, return_probs=True) for img_path in image_paths]
134
+
135
+
136
+ # ============================================================================
137
+ # 3. Example Usage
138
+ # ============================================================================
139
+ def main():
140
+ """Example demonstrating how to use the classifier with all three test images."""
141
+
142
+ # Update this path to your actual model file
143
+ MODEL_PATH = "coherence_model.safetensors" # Change to your model filename
144
+
145
+ # The three example images included in the repository
146
+ DEMO_IMAGES = [
147
+ "example_coherent.jpg",
148
+ "example_semi_incoherent.jpg",
149
+ "example_incoherent.jpg"
150
+ ]
151
+
152
+ # Check if model file exists
153
+ if not os.path.exists(MODEL_PATH):
154
+ print(f"Error: Model file not found at {MODEL_PATH}")
155
+ print("Please download the model from the Hugging Face repository:")
156
+ print("https://huggingface.co/your-username/your-model-name")
157
+ return
158
+
159
+ # Check which demo images exist
160
+ available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
161
+
162
+ if not available_images:
163
+ print("Error: No example images found.")
164
+ print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
165
+ return
166
+
167
+ print("=" * 60)
168
+ print("Coherence Detection Model - Complete Demo")
169
+ print("=" * 60)
170
+
171
+ # Initialize classifier
172
+ print(f"Loading model from: {MODEL_PATH}")
173
+ classifier = CoherenceClassifier(MODEL_PATH, device="auto")
174
+
175
+ print(f"\nFound {len(available_images)} example image(s) for demonstration.")
176
+ print("-" * 60)
177
+
178
+ # ========================================================================
179
+ # PART 1: Detailed single image analysis for each available example
180
+ # ========================================================================
181
+ print("\n1. SINGLE IMAGE ANALYSIS")
182
+ print("-" * 40)
183
+
184
+ for img_file in available_images:
185
+ print(f"\nAnalyzing: {img_file}")
186
+ print("-" * 30)
187
+
188
+ try:
189
+ # Get full probability distribution
190
+ probs = classifier.predict(img_file, return_probs=True)
191
+
192
+ # Display all probabilities
193
+ print("Prediction probabilities:")
194
+ for category in classifier.CATEGORIES:
195
+ prob = probs[category]
196
+ # Visual indicator for high confidence (>70%)
197
+ indicator = " ★" if prob > 0.7 else ""
198
+ print(f" {category:20} {prob:.4f}{indicator}")
199
+
200
+ # Get and display top category
201
+ top_category = classifier.predict(img_file, return_probs=False)
202
+ top_prob = probs[top_category]
203
+
204
+ print(f"\nTop prediction: '{top_category}' ({top_prob:.4f})")
205
+
206
+ # Add interpretation note based on image name
207
+ if img_file == "example_coherent.jpg":
208
+ print("Note: This example should ideally show high 'coherent' probability.")
209
+ elif img_file == "example_semi_incoherent.jpg":
210
+ print("Note: This example demonstrates borderline/partial coherence issues.")
211
+ elif img_file == "example_incoherent.jpg":
212
+ print("Note: This example should show obvious generation failures.")
213
+
214
+ except Exception as e:
215
+ print(f"Error analyzing {img_file}: {e}")
216
+ continue
217
+
218
+ # ========================================================================
219
+ # PART 2: Batch prediction comparison (if multiple images available)
220
+ # ========================================================================
221
+ if len(available_images) > 1:
222
+ print("\n" + "=" * 60)
223
+ print("2. BATCH PREDICTION COMPARISON")
224
+ print("-" * 40)
225
+
226
+ try:
227
+ print(f"Running batch prediction on {len(available_images)} images...")
228
+ batch_results = classifier.predict_batch(available_images)
229
+
230
+ # Create a comparison table
231
+ print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
232
+ print("-" * 65)
233
+
234
+ for img_path, result in zip(available_images, batch_results):
235
+ top_cat = max(result, key=result.get)
236
+ confidence = result[top_cat]
237
+
238
+ # Shorten filename if too long
239
+ display_name = os.path.basename(img_path)
240
+ if len(display_name) > 28:
241
+ display_name = display_name[:25] + "..."
242
+
243
+ # Color code high confidence predictions
244
+ if confidence > 0.8:
245
+ confidence_str = f"{confidence:.4f} (HIGH)"
246
+ elif confidence > 0.6:
247
+ confidence_str = f"{confidence:.4f} (MED)"
248
+ else:
249
+ confidence_str = f"{confidence:.4f} (LOW)"
250
+
251
+ print(f"{display_name:30} {top_cat:20} {confidence_str:12}")
252
+
253
+ print("\nBatch processing complete!")
254
+
255
+ except Exception as e:
256
+ print(f"Error in batch prediction: {e}")
257
+
258
+ # ========================================================================
259
+ # PART 3: Quick summary
260
+ # ========================================================================
261
+ print("\n" + "=" * 60)
262
+ print("DEMO SUMMARY")
263
+ print("-" * 40)
264
+
265
+ print(f"✓ Model loaded successfully on: {classifier.device}")
266
+ print(f"✓ Analyzed {len(available_images)} example image(s)")
267
+ print(f"✓ Example categories: {classifier.CATEGORIES}")
268
+
269
+ missing_images = [img for img in DEMO_IMAGES if img not in available_images]
270
+ if missing_images:
271
+ print(f"\nNote: Missing example images: {missing_images}")
272
+ print("To complete the demo, add these images to the directory.")
273
+
274
+ print("\n" + "=" * 60)
275
+ print("Demo completed successfully!")
276
+ print("\nNext steps:")
277
+ print("1. Try your own images by modifying the DEMO_IMAGES list")
278
+ print("2. Use the classifier in your own code:")
279
+ print(" ```python")
280
+ print(" from example_usage import CoherenceClassifier")
281
+ print(" classifier = CoherenceClassifier('your_model.safetensors')")
282
+ print(" result = classifier.predict('your_image.jpg')")
283
+ print(" ```")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ main()
model_architecture.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Exact architecture for Coherence Detection Model.
3
+ Uses key matching to validate the safetensors file.
4
+ """
5
+
6
+ import sys
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ # ============================================================================
11
+ # AdaptiveConcatPool2d for FastAI model compatibility
12
+ # ============================================================================
13
+ class AdaptiveConcatPool2d(nn.Module):
14
+ """FastAI-style adaptive concatenation pooling."""
15
+ def __init__(self, sz=None):
16
+ super().__init__()
17
+ self.ap = nn.AdaptiveAvgPool2d(sz or 1)
18
+ self.mp = nn.AdaptiveMaxPool2d(sz or 1)
19
+
20
+ def forward(self, x):
21
+ return torch.cat([self.mp(x), self.ap(x)], 1)
22
+
23
+ # Also register in __main__ if it exists
24
+ if '__main__' in sys.modules:
25
+ main_module = sys.modules['__main__']
26
+ if not hasattr(main_module, 'AdaptiveConcatPool2d'):
27
+ setattr(main_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d)
28
+
29
+ # Also register under common FastAI names
30
+ if '__main__' in sys.modules:
31
+ main_module = sys.modules['__main__']
32
+ # Some FastAI models might expect this
33
+ if not hasattr(main_module, 'AdaptiveConcatPool'):
34
+ setattr(main_module, 'AdaptiveConcatPool', AdaptiveConcatPool2d)
35
+
36
+ # ============================================================================
37
+ # Utility function to check torchvision version
38
+ # ============================================================================
39
+ def _get_resnet_backbone():
40
+ """Helper to get ResNet backbone with version-appropriate API."""
41
+ from torchvision.models import resnet34
42
+ import torchvision
43
+
44
+ # Parse version to determine API
45
+ version = torchvision.__version__.split('.')
46
+ major = int(version[0]) if version[0].isdigit() else 0
47
+ minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0
48
+
49
+ if (major, minor) >= (0, 13):
50
+ # Use new weights API
51
+ return resnet34(weights=None)
52
+ else:
53
+ # Use old pretrained API
54
+ return resnet34(pretrained=False)
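+ Comparing the parsed components as a tuple avoids the pitfall of testing major and minor independently (e.g. a hypothetical torchvision 1.0 would fail a bare `minor >= 13` check despite supporting the new `weights=` API). A standalone sketch (`parse_version` is an illustrative helper, not part of this repository):

```python
# Sketch: parse "major.minor" from a version string, ignoring any
# non-numeric suffix, then compare component-wise as a tuple.
def parse_version(version_string):
    parts = []
    for piece in version_string.split('.')[:2]:
        digits = ''
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at suffixes like "13rc1"
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

print(parse_version("0.13.1") >= (0, 13))  # True
print(parse_version("0.9.0") >= (0, 13))   # False
print(parse_version("1.0.0") >= (0, 13))   # True
```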
55
+
56
+
57
+ # ============================================================================
+ # Clean model with version detection
+ # ============================================================================
+ class CoherenceDetectionModel(nn.Sequential):
+     """
+     Clean version that handles torchvision API changes properly.
+     """
+     def __init__(self, num_classes=3):
+         # Get backbone using version-appropriate API, dropping avgpool and fc
+         backbone = _get_resnet_backbone()
+         backbone = nn.Sequential(*list(backbone.children())[:-2])
+
+         # FastAI-style classifier head
+         classifier = nn.Sequential(
+             AdaptiveConcatPool2d(),
+             nn.Flatten(start_dim=1, end_dim=-1),
+             nn.BatchNorm1d(1024),
+             nn.Dropout(p=0.25, inplace=False),
+             nn.Linear(1024, 512, bias=True),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm1d(512),
+             nn.Dropout(p=0.5, inplace=False),
+             nn.Linear(512, num_classes, bias=True)
+         )
+
+         super().__init__(backbone, classifier)
+
+
+ # ============================================================================
+ # Loading function
+ # ============================================================================
+ def load_coherence_model(safetensors_path, device='auto'):
+     """
+     Load safetensors weights into a fresh CoherenceDetectionModel.
+
+     Args:
+         safetensors_path: Path to .safetensors file
+         device: 'auto', 'cuda', or 'cpu'
+     """
+     import safetensors.torch
+
+     # Determine device
+     if device == 'auto':
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     else:
+         device = torch.device(device)
+
+     # Create model instance
+     model = CoherenceDetectionModel(num_classes=3)
+
+     # Load safetensors
+     try:
+         state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
+     except FileNotFoundError:
+         print(f"Error: File '{safetensors_path}' not found.")
+         print("Testing with sample model structure...")
+         state_dict = model.state_dict()
+
+     # Load directly (keys should match exactly)
+     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+     if missing_keys:
+         print(f"Warning: Missing keys: {missing_keys}")
+     if unexpected_keys:
+         print(f"Warning: Unexpected keys: {unexpected_keys}")
+
+     if not missing_keys and not unexpected_keys:
+         print("✓ CoherenceDetectionModel loaded successfully (exact match)")
+     else:
+         print("⚠ CoherenceDetectionModel loaded with key mismatches")
+
+     model = model.to(device)
+     model.eval()
+
+     print(f"  Device: {device}")
+     print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     return model
+
+
+ # ============================================================================
+ # Test functions
+ # ============================================================================
+ def test_key_matching(safetensors_path="coherence_model.safetensors"):
+     """Test that keys match between model and safetensors."""
+     import safetensors.torch
+
+     print("\nTesting key matching...")
+
+     try:
+         state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
+     except FileNotFoundError:
+         print(f"  ⚠ File '{safetensors_path}' not found, using model weights")
+         model = CoherenceDetectionModel(num_classes=3)
+         state_dict = model.state_dict()
+
+     print("\nTesting CoherenceDetectionModel:")
+     model_clean = CoherenceDetectionModel(num_classes=3)
+     missing, unexpected = model_clean.load_state_dict(state_dict, strict=False)
+
+     if not missing and not unexpected:
+         print("  ✅ Load successful (exact key match)")
+     else:
+         print("  ⚠ Load completed with issues")
+         if missing:
+             print(f"    Missing keys: {len(missing)}")
+         if unexpected:
+             print(f"    Unexpected keys: {len(unexpected)}")
+
+     return model_clean if not missing and not unexpected else None
+
+
+ def print_key_samples(safetensors_path="coherence_model.safetensors"):
+     """Print sample keys for debugging."""
+     import safetensors.torch
+
+     print("\nKey samples:")
+
+     try:
+         state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
+         print("From safetensors file (first 5 keys):")
+         for i, key in enumerate(sorted(state_dict.keys())[:5]):
+             print(f"  {i}: {key}")
+     except FileNotFoundError:
+         print(f"Safetensors file '{safetensors_path}' not found")
+         print("Showing model structure keys instead:")
+         state_dict = None
+
+     print("\nFrom CoherenceDetectionModel:")
+     model = CoherenceDetectionModel(num_classes=3)
+     for i, key in enumerate(sorted(model.state_dict().keys())[:5]):
+         print(f"  {i}: {key}")
+
+     return state_dict
+
+
+ # ============================================================================
+ # Version compatibility info
+ # ============================================================================
+ def print_version_info():
+     """Print version information for debugging."""
+     import torch
+     import torchvision
+
+     print("\n" + "=" * 60)
+     print("Version Information")
+     print("=" * 60)
+     print(f"Torch: {torch.__version__}")
+     print(f"Torchvision: {torchvision.__version__}")
+     print(f"CUDA Available: {torch.cuda.is_available()}")
+
+     # Check API compatibility
+     version = torchvision.__version__.split('.')
+     major = int(version[0]) if version[0].isdigit() else 0
+     minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0
+
+     if (major, minor) >= (0, 13):
+         print("✓ Using modern torchvision API (weights parameter)")
+     else:
+         print("⚠ Using legacy torchvision API (pretrained parameter)")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     print("=" * 60)
+     print("Coherence Detection Model Architecture")
+     print("=" * 60)
+
+     print_version_info()
+     state_dict = print_key_samples()
+     print("\n" + "=" * 60)
+     model = test_key_matching()
+
+     if model:
+         print("\nModel summary:")
+         print(f"  Backbone layers: {len(model[0])}")
+         print(f"  Classifier layers: {len(model[1])}")
+         print(f"  Total sequential blocks: {len(model)}")
+ print(f" Total sequential blocks: {len(model)}")