Convert to safetensors format and add model architecture

Files changed:
- .gitignore (+1, -0)
- README.md (+7, -10)
- coherence_model.pth → coherence_model.safetensors (+2, -2)
- example_usage.py (+287, -326)
- model_architecture.py (+235, -0)
.gitignore
ADDED

```diff
@@ -0,0 +1 @@
+__pycache__
```
README.md
CHANGED

````diff
@@ -18,10 +18,10 @@ A ResNet-34 fine-tuned on a personally-curated dataset to classify images into o
 - Incoherent
 - Semi-Incoherent
 
-**Key Feature**: This model
+**Key Feature**: This model is provided in safetensors format with a production-ready loading wrapper (`model_architecture.py`) that handles FastAI's `AdaptiveConcatPool2d` layer automatically.
 
 ## Installation and Usage
-**Important**: This model
+**Important**: This model is provided in safetensors format and requires the `model_architecture.py` module for proper loading.
 
 1. **Install**
 ```bash
@@ -38,7 +38,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 pip install -r requirements.txt
 ```
 
-2. **Download and run the example**:
+2. **Download entire repo and run the example**:
 ```bash
 python example_usage.py
 ```
@@ -47,16 +47,13 @@ python example_usage.py
 ```python
 from example_usage import CoherenceClassifier
 # Initialize with your model
-classifier = CoherenceClassifier("coherence_model.
+classifier = CoherenceClassifier("coherence_model.safetensors")
 # Predict on an image
 result = classifier.predict("your_image.jpg", return_probs=True)
 print(result) # {'coherent': 0.85, 'incoherent': 0.05, 'semi-incoherent': 0.10}
 ```
 
-Note
-
-- Use `weights_only=False` (only if you trust the source), OR
-- Register `AdaptiveConcatPool2d` in `__main__` before loading (see example_usage.py).
+**Note**: The model uses FastAI's `AdaptiveConcatPool2d` layer. Import and use `model_architecture.py`, which handles this automatically. The `example_usage.py` script demonstrates the proper import pattern.
 
 ## Model Architecture
 Backbone: ResNet-34 (via FastAI's default)
@@ -68,7 +65,7 @@ For version 1.0, a small dataset was used of ~20k images in coherent category, a
 
 ## Limitations
 - PyTorch environment
-- `
+- Requires `model_architecture.py` module for proper loading (handles FastAI-specific layers automatically)
 
 Note especially that while an attempt was made at classifying the less obvious but definitely noticeable generation failures like the following into the "semi-incoherent" category, these are much harder to detect and in the current iteration of this model there is no expectation that they will be observed in any particular coherence category, especially when these particular errors occur in a small part of the image.
 - Extra or missing limbs, fingers, or facial features
@@ -83,7 +80,7 @@ Note especially that while an attempt was made at classifying the less obvious b
 ## Comment
 Given the low coherence rate of results produced by early image generation models, it was very surprising that a model was not found for this purpose, necessitating the creation of this one for high-volume review scenarios.
 
-Perhaps models such as this one are avoided or seen as improper due to the perceived danger they pose in introduction of bias to image analysis, however it is highly likely that image generators would rather have at least some bias towards coherence and a somewhat clear mind when reviewing their image output than no bias and a mind littered with the psychologically
+Perhaps models such as this one are avoided or seen as improper due to the perceived danger they pose in introduction of bias to image analysis, however it is highly likely that image generators would rather have at least some bias towards coherence and a somewhat clear mind when reviewing their image output than no bias and a mind littered with the psychologically-damaging results of obviously-failed generations which have little to do with the prompter's intent.
 
 ## Model Card Authors
 Tom Hall
````
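The probability dictionary returned by `predict(..., return_probs=True)` can drive a simple accept/reject gate in a high-volume review pipeline. A minimal stdlib-only sketch; the `accept_image` helper and the 0.7 threshold are illustrative, not part of this repository:

```python
def accept_image(probs, threshold=0.7):
    """Accept only when 'coherent' is the top category and clears the threshold."""
    coherent = probs.get("coherent", 0.0)
    return coherent >= threshold and coherent == max(probs.values())

# Example output dict from the README's quick-start snippet
result = {"coherent": 0.85, "incoherent": 0.05, "semi-incoherent": 0.10}
print(accept_image(result))  # True
print(accept_image({"coherent": 0.40, "incoherent": 0.50, "semi-incoherent": 0.10}))  # False
```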
coherence_model.pth → coherence_model.safetensors
RENAMED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b2f937854cc4ae58ebbe2e6b451006a5ee8dc65fabe5c32d33ef937cae24724b
+size 87356700
```
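Both the old and new weight files are git-lfs pointer files (`key value` lines), not the tensors themselves. A stdlib-only sketch of reading one, using the pointer contents above; the helper name is illustrative:

```python
def parse_lfs_pointer(text):
    """Split a git-lfs pointer file into a {key: value} dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:b2f937854cc4ae58ebbe2e6b451006a5ee8dc65fabe5c32d33ef937cae24724b\n"
    "size 87356700\n"
)
info = parse_lfs_pointer(pointer)
print(info["size"])  # 87356700
```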
example_usage.py
CHANGED

The file was rewritten in full (326 lines removed, 287 added). New contents:

````python
"""
Minimal example to load and use the Coherence Detection model.
Requires safetensors format with exact architecture.
"""

import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import sys
import os

# ============================================================================
# Import the exact architecture
# ============================================================================
try:
    from model_architecture import load_coherence_model
    print("✓ Imported exact model architecture")
except ImportError as e:
    print(f"Error: model_architecture.py not found or has issues: {e}")
    print("Please download it from the repository.")
    sys.exit(1)

# ============================================================================
# CoherenceClassifier for safetensors
# ============================================================================
class CoherenceClassifier:
    """Loader for coherence detection model (exact architecture)."""

    # Categories in alphabetical order (as per training)
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """
        Args:
            model_path: Path to .safetensors file
            device: "auto", "cuda", or "cpu"
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None

        # ImageNet normalization (standard for ResNet)
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]

        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Determine the best available device."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load model using safetensors and exact architecture."""
        print("Loading coherence model...")
        print(f"  File: {os.path.basename(self.model_path)}")
        print(f"  Device: {self.device}")

        # Verify file type
        if not self.model_path.lower().endswith('.safetensors'):
            print("⚠️ Warning: Expected .safetensors file for secure loading.")

        # Load using our helper function
        self.model = load_coherence_model(self.model_path, str(self.device))

    def _setup_transforms(self):
        """Create image preprocessing pipeline."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet standard size
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
        ])

    def preprocess_image(self, image_path):
        """Load and preprocess an image for the model."""
        try:
            # Open and convert to RGB
            image = Image.open(image_path).convert('RGB')

            # Apply transformations
            tensor = self.transform(image)

            # Add batch dimension [1, 3, 224, 224]
            tensor = tensor.unsqueeze(0).to(self.device)

            return tensor

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def predict(self, image_path, return_probs=False):
        """
        Make a prediction on an image.

        Args:
            image_path: Path to the image file
            return_probs: If True, return all probabilities; if False, return only the top category

        Returns:
            Dictionary with predictions or string with top category
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        # Preprocess the image
        input_tensor = self.preprocess_image(image_path)

        # Run inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Convert to probabilities using softmax
        probabilities = torch.nn.functional.softmax(output, dim=1)
        probs = probabilities[0].cpu().numpy()

        # Create results dictionary
        results = {self.CATEGORIES[i]: float(probs[i]) for i in range(len(self.CATEGORIES))}

        if return_probs:
            return results
        else:
            # Return the category with highest probability
            top_idx = np.argmax(probs)
            return self.CATEGORIES[top_idx]

    def predict_batch(self, image_paths):
        """Make predictions for multiple images."""
        return [self.predict(img_path, return_probs=True) for img_path in image_paths]


# ============================================================================
# Example Usage
# ============================================================================
def main():
    """Example demonstrating how to use the classifier with all three test images."""

    # Update this path to your actual model file
    MODEL_PATH = "coherence_model.safetensors"  # Change to your model filename

    # All three example images
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg"
    ]

    # Check if model file exists
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return

    # Check which demo images exist
    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]

    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return

    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)

    # Initialize classifier
    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")

    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)

    # ========================================================================
    # PART 1: Detailed single image analysis for each available example
    # ========================================================================
    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)

    for img_file in available_images:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)

        try:
            # Get full probability distribution
            probs = classifier.predict(img_file, return_probs=True)

            # Display all probabilities
            print("Prediction probabilities:")
            for category in classifier.CATEGORIES:
                prob = probs[category]
                # Visual indicator for high confidence (>70%)
                indicator = " ★" if prob > 0.7 else ""
                print(f"  {category:20} {prob:.4f}{indicator}")

            # Get and display top category
            top_category = classifier.predict(img_file, return_probs=False)
            top_prob = probs[top_category]

            print(f"\nTop prediction: '{top_category}' ({top_prob:.4f})")

            # Add interpretation note based on image name
            if img_file == "example_coherent.jpg":
                print("Note: This example should ideally show high 'coherent' probability.")
            elif img_file == "example_semi_incoherent.jpg":
                print("Note: This example demonstrates borderline/partial coherence issues.")
            elif img_file == "example_incoherent.jpg":
                print("Note: This example should show obvious generation failures.")

        except Exception as e:
            print(f"Error analyzing {img_file}: {e}")
            continue

    # ========================================================================
    # PART 2: Batch prediction comparison (if multiple images available)
    # ========================================================================
    if len(available_images) > 1:
        print("\n" + "=" * 60)
        print("2. BATCH PREDICTION COMPARISON")
        print("-" * 40)

        try:
            print(f"Running batch prediction on {len(available_images)} images...")
            batch_results = classifier.predict_batch(available_images)

            # Create a comparison table
            print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
            print("-" * 65)

            for img_path, result in zip(available_images, batch_results):
                top_cat = max(result, key=result.get)
                confidence = result[top_cat]

                # Shorten filename if too long
                display_name = os.path.basename(img_path)
                if len(display_name) > 28:
                    display_name = display_name[:25] + "..."

                # Color code high confidence predictions
                if confidence > 0.8:
                    confidence_str = f"{confidence:.4f} (HIGH)"
                elif confidence > 0.6:
                    confidence_str = f"{confidence:.4f} (MED)"
                else:
                    confidence_str = f"{confidence:.4f} (LOW)"

                print(f"{display_name:30} {top_cat:20} {confidence_str:12}")

            print("\nBatch processing complete!")

        except Exception as e:
            print(f"Error in batch prediction: {e}")

    # ========================================================================
    # PART 3: Quick summary
    # ========================================================================
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)

    print(f"✓ Model loaded successfully on: {classifier.device}")
    print(f"✓ Analyzed {len(available_images)} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")

    missing_images = [img for img in DEMO_IMAGES if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")

    print("\n" + "=" * 60)
    print("Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print("   ```python")
    print("   from example_usage import CoherenceClassifier")
    print("   classifier = CoherenceClassifier('your_model.safetensors')")
    print("   result = classifier.predict('your_image.jpg')")
    print("   ```")


if __name__ == "__main__":
    main()
````
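The batch-comparison table logic in `example_usage.py` can be exercised without a model or images; a stdlib-only sketch of one table row (the `table_row` helper is illustrative, the thresholds match the script's):

```python
def table_row(image_name, probs):
    """Format one comparison-table row from a {category: probability} dict."""
    top_cat = max(probs, key=probs.get)      # highest-probability category
    confidence = probs[top_cat]
    if confidence > 0.8:
        tag = "HIGH"
    elif confidence > 0.6:
        tag = "MED"
    else:
        tag = "LOW"
    if len(image_name) > 28:                 # shorten long filenames
        image_name = image_name[:25] + "..."
    return f"{image_name:30} {top_cat:20} {confidence:.4f} ({tag})"

row = table_row("your_image.jpg", {"coherent": 0.85, "incoherent": 0.05, "semi-incoherent": 0.10})
print(row)
```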
model_architecture.py
ADDED

```python
"""
Exact architecture for Coherence Detection Model.
Uses key matching to validate the safetensors file.
"""

import sys
import torch
import torch.nn as nn

# ============================================================================
# AdaptiveConcatPool2d for FastAI model compatibility
# ============================================================================
class AdaptiveConcatPool2d(nn.Module):
    """FastAI-style adaptive concatenation pooling."""
    def __init__(self, sz=None):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d(sz or 1)
        self.mp = nn.AdaptiveMaxPool2d(sz or 1)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)

# Also register in __main__ if it exists
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    if not hasattr(main_module, 'AdaptiveConcatPool2d'):
        setattr(main_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d)

# Also register under common FastAI names
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    # Some FastAI models might expect this
    if not hasattr(main_module, 'AdaptiveConcatPool'):
        setattr(main_module, 'AdaptiveConcatPool', AdaptiveConcatPool2d)

# ============================================================================
# Utility function to check torchvision version
# ============================================================================
def _get_resnet_backbone():
    """Helper to get ResNet backbone with version-appropriate API."""
    from torchvision.models import resnet34
    import torchvision

    # Parse version to determine API
    version = torchvision.__version__.split('.')
    major = int(version[0]) if version[0].isdigit() else 0
    minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0

    if major > 0 or minor >= 13:
        # Use new weights API (torchvision >= 0.13)
        return resnet34(weights=None)
    else:
        # Use old pretrained API
        return resnet34(pretrained=False)


# ============================================================================
# Clean model with version detection
# ============================================================================
class CoherenceDetectionModel(nn.Sequential):
    """
    Clean version that handles torchvision API changes properly.
    """
    def __init__(self, num_classes=3):
        # Get backbone using version-appropriate API
        backbone = _get_resnet_backbone()
        backbone = nn.Sequential(*list(backbone.children())[:-2])

        # Classifier head
        classifier = nn.Sequential(
            AdaptiveConcatPool2d(),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.25, inplace=False),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(512, num_classes, bias=True)
        )

        super().__init__(backbone, classifier)


# ============================================================================
# Loading function
# ============================================================================
def load_coherence_model(safetensors_path, device='auto'):
    """
    Load safetensors weights into the exact architecture and validate key matching.

    Args:
        safetensors_path: Path to .safetensors file
        device: 'auto', 'cuda', or 'cpu'
    """
    import safetensors.torch

    # Determine device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    # Create model instance
    model = CoherenceDetectionModel(num_classes=3)

    # Load safetensors
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Error: File '{safetensors_path}' not found.")
        print("Testing with sample model structure...")
        state_dict = model.state_dict()

    # Load directly (keys should match exactly)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")

    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    if not missing_keys and not unexpected_keys:
        print("✓ CoherenceDetectionModel loaded successfully (exact match)")
    else:
        print("⚠ CoherenceDetectionModel loaded with key mismatches")

    model = model.to(device)
    model.eval()

    print(f"  Device: {device}")
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model


# ============================================================================
# Test functions
# ============================================================================
def test_key_matching(safetensors_path="coherence_model.safetensors"):
    """Test that keys match between model and safetensors."""
    import safetensors.torch

    print("\nTesting key matching...")

    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"  ⚠ File '{safetensors_path}' not found, using model weights")
        model = CoherenceDetectionModel(num_classes=3)
        state_dict = model.state_dict()

    print("\nTesting CoherenceDetectionModel:")
    model_clean = CoherenceDetectionModel(num_classes=3)
    missing, unexpected = model_clean.load_state_dict(state_dict, strict=False)

    if not missing and not unexpected:
        print("  ✅ Load successful (exact key match)")
    else:
        print("  ⚠ Load completed with issues")
        if missing:
            print(f"    Missing keys: {len(missing)}")
        if unexpected:
            print(f"    Unexpected keys: {len(unexpected)}")

    return model_clean if not missing and not unexpected else None


def print_key_samples(safetensors_path="coherence_model.safetensors"):
    """Print sample keys for debugging."""
    import safetensors.torch

    print("\nKey samples:")

    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
        print("From safetensors file (first 5 keys):")
        for i, key in enumerate(sorted(state_dict.keys())[:5]):
            print(f"  {i}: {key}")
    except FileNotFoundError:
        print(f"Safetensors file '{safetensors_path}' not found")
        print("Showing model structure keys instead:")
        state_dict = None

    print("\nFrom CoherenceDetectionModel:")
    model = CoherenceDetectionModel(num_classes=3)
    for i, key in enumerate(sorted(model.state_dict().keys())[:5]):
        print(f"  {i}: {key}")

    return state_dict


# ============================================================================
# Version compatibility info
# ============================================================================
def print_version_info():
    """Print version information for debugging."""
    import torch
    import torchvision

    print("\n" + "=" * 60)
    print("Version Information")
    print("=" * 60)
    print(f"Torch: {torch.__version__}")
    print(f"Torchvision: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # Check API compatibility
    version = torchvision.__version__.split('.')
    major = int(version[0]) if version[0].isdigit() else 0
    minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0

    if major > 0 or minor >= 13:
        print("✓ Using modern torchvision API (weights parameter)")
    else:
        print("⚠ Using legacy torchvision API (pretrained parameter)")
    print("=" * 60)


if __name__ == "__main__":
    print("=" * 60)
    print("Coherence Detection Model Architecture")
    print("=" * 60)

    print_version_info()
    state_dict = print_key_samples()
    print("\n" + "=" * 60)
    model = test_key_matching()

    if model:
        print("\nModel summary:")
        print(f"  Backbone layers: {len(model[0])}")
        print(f"  Classifier layers: {len(model[1])}")
        print(f"  Total sequential blocks: {len(model)}")
```