File size: 11,154 Bytes
c5a3ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
"""
Minimal example to load and use the Coherence Detection model.
Requires safetensors format with exact architecture.
"""
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import sys
import os
# ============================================================================
# Import the exact architecture
# ============================================================================
# Fail fast with a clear message: nothing below can run without the exact
# model architecture, so a missing companion module ends the script here.
try:
    from model_architecture import load_coherence_model
    print("✓ Imported exact model architecture")
except ImportError as e:
    print(f"Error: model_architecture.py not found or has issues: {e}")
    print("Please download it from the repository.")
    sys.exit(1)
# ============================================================================
# CoherenceClassifier for safetensors
# ============================================================================
class CoherenceClassifier:
    """Loader and inference wrapper for the coherence detection model.

    Wraps a safetensors checkpoint (loaded via ``load_coherence_model``) and
    classifies images into three coherence categories, returning either the
    full probability distribution or just the top category.
    """

    # Categories in alphabetical order (as per training).
    # The index order must match the model's output logits.
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """
        Args:
            model_path: Path to .safetensors file
            device: "auto", "cuda", or "cpu"
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None
        # ImageNet normalization (standard for ResNet backbones)
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]
        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Resolve a device string to a torch.device, preferring CUDA on "auto"."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load the model weights using safetensors and the exact architecture."""
        # Plain string literal (was an f-string with no placeholders — F541).
        print("Loading coherence model...")
        print(f" File: {os.path.basename(self.model_path)}")
        print(f" Device: {self.device}")
        # Warn (but do not fail) on other file types; safetensors is the
        # secure format this loader is designed around.
        if not self.model_path.lower().endswith('.safetensors'):
            print("⚠️ Warning: Expected .safetensors file for secure loading.")
        # Delegate to the helper that knows the exact architecture.
        self.model = load_coherence_model(self.model_path, str(self.device))

    def _setup_transforms(self):
        """Create the preprocessing pipeline: resize -> tensor -> normalize."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet standard input size
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
        ])

    def preprocess_image(self, image_path):
        """Load an image and return a normalized [1, 3, 224, 224] tensor on device.

        Raises:
            Exception: re-raises any PIL/torch error after logging the path.
        """
        try:
            # Open and force 3-channel RGB (handles grayscale/RGBA inputs).
            image = Image.open(image_path).convert('RGB')
            tensor = self.transform(image)
            # Add the batch dimension and move to the model's device.
            return tensor.unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def predict(self, image_path, return_probs=False):
        """
        Make a prediction on an image.

        Args:
            image_path: Path to the image file
            return_probs: If True, return all probabilities; if False, return only the top category

        Returns:
            Dict mapping category -> probability, or the top category string.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        input_tensor = self.preprocess_image(image_path)
        with torch.no_grad():
            output = self.model(input_tensor)
            # Logits -> probabilities over the three categories.
            probabilities = torch.nn.functional.softmax(output, dim=1)
            probs = probabilities[0].cpu().numpy()
        results = {cat: float(p) for cat, p in zip(self.CATEGORIES, probs)}
        if return_probs:
            return results
        # Highest-probability category; on exact ties the first category wins,
        # matching the previous np.argmax behavior.
        return max(results, key=results.get)

    def predict_batch(self, image_paths):
        """Return a probability dict (see predict) for each image path."""
        return [self.predict(path, return_probs=True) for path in image_paths]
# ============================================================================
# Example Usage
# ============================================================================
def _analyze_single_images(classifier, available_images):
    """Part 1: per-image probability breakdown and top prediction."""
    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)
    for img_file in available_images:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)
        try:
            # One forward pass per image; the top category is derived from the
            # returned distribution instead of running inference a second time.
            probs = classifier.predict(img_file, return_probs=True)
            print("Prediction probabilities:")
            for category in classifier.CATEGORIES:
                prob = probs[category]
                # Visual indicator for high confidence (>70%)
                indicator = " ★" if prob > 0.7 else ""
                print(f" {category:20} {prob:.4f}{indicator}")
            top_category = max(probs, key=probs.get)
            top_prob = probs[top_category]
            print(f"\nTop prediction: '{top_category}' ({top_prob:.4f})")
            # Interpretation notes keyed to the bundled example images.
            if img_file == "example_coherent.jpg":
                print("Note: This example should ideally show high 'coherent' probability.")
            elif img_file == "example_semi_incoherent.jpg":
                print("Note: This example demonstrates borderline/partial coherence issues.")
            elif img_file == "example_incoherent.jpg":
                print("Note: This example should show obvious generation failures.")
        except Exception as e:
            print(f"Error analyzing {img_file}: {e}")
            continue


def _compare_batch(classifier, available_images):
    """Part 2: batch prediction with a side-by-side comparison table."""
    print("\n" + "=" * 60)
    print("2. BATCH PREDICTION COMPARISON")
    print("-" * 40)
    try:
        print(f"Running batch prediction on {len(available_images)} images...")
        batch_results = classifier.predict_batch(available_images)
        print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
        print("-" * 65)
        for img_path, result in zip(available_images, batch_results):
            top_cat = max(result, key=result.get)
            confidence = result[top_cat]
            # Shorten the filename so it fits the 30-character column.
            display_name = os.path.basename(img_path)
            if len(display_name) > 28:
                display_name = display_name[:25] + "..."
            # Tag confidence tiers so strong predictions stand out.
            if confidence > 0.8:
                confidence_str = f"{confidence:.4f} (HIGH)"
            elif confidence > 0.6:
                confidence_str = f"{confidence:.4f} (MED)"
            else:
                confidence_str = f"{confidence:.4f} (LOW)"
            print(f"{display_name:30} {top_cat:20} {confidence_str:12}")
        print("\nBatch processing complete!")
    except Exception as e:
        print(f"Error in batch prediction: {e}")


def _print_summary(classifier, available_images, demo_images):
    """Part 3: summary of the demo run plus pointers for next steps."""
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)
    print(f"✓ Model loaded successfully on: {classifier.device}")
    print(f"✓ Analyzed {len(available_images)} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")
    missing_images = [img for img in demo_images if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")
    print("\n" + "=" * 60)
    print("Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print(" ```python")
    print(" from example_usage import CoherenceClassifier")
    # Fixed: recommend the safetensors file this loader expects (the snippet
    # previously suggested a .pth file, which load_model warns against).
    print(" classifier = CoherenceClassifier('your_model.safetensors')")
    print(" result = classifier.predict('your_image.jpg')")
    print(" ```")


def main():
    """Example demonstrating how to use the classifier with all three test images."""
    # Update this path to your actual model file
    MODEL_PATH = "coherence_model.safetensors"  # Change to your model filename
    # All three example images you've provided
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg"
    ]
    # Guard clause: nothing to demo without the model file.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return
    # Guard clause: need at least one example image on disk.
    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return
    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)
    # Initialize classifier
    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")
    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)
    _analyze_single_images(classifier, available_images)
    # Batch comparison only makes sense with more than one image.
    if len(available_images) > 1:
        _compare_batch(classifier, available_images)
    _print_summary(classifier, available_images, DEMO_IMAGES)
# Run the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|