# coherence_detection/example_usage.py
# Source: Hugging Face repository, commit c5a3ef9
# ("Convert to safetensors format and add model architecture", by reddesert).
"""
Minimal example to load and use the Coherence Detection model.
Requires safetensors format with exact architecture.
"""
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import sys
import os
# ============================================================================
# Import the exact architecture
# ============================================================================
# The exact model architecture is required; fail fast with guidance if absent.
try:
    from model_architecture import load_coherence_model
except ImportError as e:
    print(f"Error: model_architecture.py not found or has issues: {e}")
    print("Please download it from the repository.")
    sys.exit(1)
else:
    print("✓ Imported exact model architecture")
# ============================================================================
# CoherenceClassifier for safetensors
# ============================================================================
class CoherenceClassifier:
    """Loader and inference wrapper for the coherence detection model.

    Expects a ``.safetensors`` checkpoint matching the exact architecture
    provided by ``model_architecture.load_coherence_model``.
    """

    # Categories in alphabetical order (as per training): model output
    # index i corresponds to CATEGORIES[i].
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """
        Args:
            model_path: Path to .safetensors file.
            device: "auto", "cuda", or "cpu" ("auto" picks CUDA if available).
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None
        # ImageNet normalization constants (standard for ResNet backbones).
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]
        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Resolve a device string to a torch.device, auto-detecting CUDA."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load model weights via the exact-architecture helper.

        Warns (but does not fail) if the file does not use the expected
        .safetensors extension, since other formats may still load.
        """
        # NOTE: was f-string with no placeholders (ruff F541); plain str now.
        print("Loading coherence model...")
        print(f" File: {os.path.basename(self.model_path)}")
        print(f" Device: {self.device}")
        if not self.model_path.lower().endswith('.safetensors'):
            print("⚠️ Warning: Expected .safetensors file for secure loading.")
        self.model = load_coherence_model(self.model_path, str(self.device))

    def _setup_transforms(self):
        """Build the preprocessing pipeline: resize → tensor → normalize."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet standard input size
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
        ])

    def preprocess_image(self, image_path):
        """Load an image and return a normalized [1, 3, 224, 224] tensor on the device.

        Raises:
            Exception: re-raises any load/transform error after logging it.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            tensor = self.transform(image)
            # Add batch dimension and move to the inference device.
            return tensor.unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def predict(self, image_path, return_probs=False):
        """
        Make a prediction on an image.

        Args:
            image_path: Path to the image file
            return_probs: If True, return all probabilities; if False, return only the top category

        Returns:
            dict mapping category -> probability when return_probs is True,
            otherwise the top category name (str).

        Raises:
            ValueError: if the model has not been loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        input_tensor = self.preprocess_image(image_path)
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            output = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
        probs = probabilities[0].cpu().tolist()
        results = {cat: float(p) for cat, p in zip(self.CATEGORIES, probs)}
        if return_probs:
            return results
        # Highest-probability category; max() returns the first maximum in
        # CATEGORIES order, matching np.argmax over the same values.
        return max(results, key=results.get)

    def predict_batch(self, image_paths):
        """Return the full probability dict for each image in image_paths."""
        return [self.predict(img_path, return_probs=True) for img_path in image_paths]
# ============================================================================
# Example Usage
# ============================================================================
def main():
    """Example demonstrating how to use the classifier with all three test images."""
    # Path to the trained weights; update this to your actual model file.
    MODEL_PATH = "coherence_model.safetensors"  # Change to your model filename
    # One canonical example image per category.
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg"
    ]
    # Bail out early with download instructions if the weights are absent.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return
    # Keep only the demo images that actually exist on disk.
    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return
    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)
    # Initialize classifier (loads weights and builds preprocessing pipeline).
    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")
    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)
    # ========================================================================
    # PART 1: Detailed single image analysis for each available example
    # ========================================================================
    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)
    for img_file in available_images:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)
        try:
            # Get the full probability distribution over all categories.
            probs = classifier.predict(img_file, return_probs=True)
            # Display all probabilities
            print("Prediction probabilities:")
            for category in classifier.CATEGORIES:
                prob = probs[category]
                # Visual indicator for high confidence (>70%)
                indicator = " ★" if prob > 0.7 else ""
                print(f" {category:20} {prob:.4f}{indicator}")
            # Get and display the top category (second predict call re-runs
            # inference; the result matches the max of `probs`).
            top_category = classifier.predict(img_file, return_probs=False)
            top_prob = probs[top_category]
            print(f"\nTop prediction: '{top_category}' ({top_prob:.4f})")
            # Interpretation note keyed on the canonical example filenames.
            if img_file == "example_coherent.jpg":
                print("Note: This example should ideally show high 'coherent' probability.")
            elif img_file == "example_semi_incoherent.jpg":
                print("Note: This example demonstrates borderline/partial coherence issues.")
            elif img_file == "example_incoherent.jpg":
                print("Note: This example should show obvious generation failures.")
        except Exception as e:
            # Keep the demo going even if one image fails to load/classify.
            print(f"Error analyzing {img_file}: {e}")
            continue
    # ========================================================================
    # PART 2: Batch prediction comparison (if multiple images available)
    # ========================================================================
    if len(available_images) > 1:
        print("\n" + "=" * 60)
        print("2. BATCH PREDICTION COMPARISON")
        print("-" * 40)
        try:
            print(f"Running batch prediction on {len(available_images)} images...")
            batch_results = classifier.predict_batch(available_images)
            # Comparison table header.
            print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
            print("-" * 65)
            for img_path, result in zip(available_images, batch_results):
                top_cat = max(result, key=result.get)
                confidence = result[top_cat]
                # Truncate long filenames so the table columns stay aligned.
                display_name = os.path.basename(img_path)
                if len(display_name) > 28:
                    display_name = display_name[:25] + "..."
                # Label confidence bands (HIGH > 0.8, MED > 0.6, else LOW).
                if confidence > 0.8:
                    confidence_str = f"{confidence:.4f} (HIGH)"
                elif confidence > 0.6:
                    confidence_str = f"{confidence:.4f} (MED)"
                else:
                    confidence_str = f"{confidence:.4f} (LOW)"
                print(f"{display_name:30} {top_cat:20} {confidence_str:12}")
            print("\nBatch processing complete!")
        except Exception as e:
            print(f"Error in batch prediction: {e}")
    # ========================================================================
    # PART 3: Quick summary
    # ========================================================================
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)
    print(f"✓ Model loaded successfully on: {classifier.device}")
    print(f"✓ Analyzed {len(available_images)} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")
    # Tell the user which canonical demo images were not found on disk.
    missing_images = [img for img in DEMO_IMAGES if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")
    print("\n" + "=" * 60)
    print("Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print(" ```python")
    print(" from example_usage import CoherenceClassifier")
    print(" classifier = CoherenceClassifier('your_model.pth')")
    print(" result = classifier.predict('your_image.jpg')")
    print(" ```")
# Script entry point: run the full demo only when executed directly.
if __name__ == "__main__":
    main()