"""
Minimal example to load and use the Coherence Detection model.
Requires safetensors format with exact architecture.
"""

import os
import sys

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# ============================================================================
# Import the exact architecture
# ============================================================================

try:
    from model_architecture import load_coherence_model
    print("✓ Imported exact model architecture")
except ImportError as e:
    print(f"Error: model_architecture.py not found or has issues: {e}")
    print("Please download it from the repository.")
    sys.exit(1)


# ============================================================================
# CoherenceClassifier for safetensors
# ============================================================================

class CoherenceClassifier:
    """Loader for coherence detection model (exact architecture).

    Bundles model loading, image preprocessing (resize to 224x224, tensor
    conversion, ImageNet normalization) and softmax inference over the
    classes listed in ``CATEGORIES``.
    """

    # Categories in alphabetical order (as per training). Output logit i of
    # the model is reported as CATEGORIES[i].
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """
        Args:
            model_path: Path to .safetensors file (str or os.PathLike).
            device: "auto", "cuda", or "cpu". "auto" picks CUDA when available.
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None
        # ImageNet normalization (standard for ResNet)
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]
        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Return a ``torch.device``; "auto" resolves to CUDA if available, else CPU."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load the model via ``load_coherence_model`` and switch it to eval mode.

        Prints a warning (but still attempts the load) when the path does not
        end in ``.safetensors``.
        """
        print("Loading coherence model...")
        print(f" File: {os.path.basename(self.model_path)}")
        print(f" Device: {self.device}")
        # Verify file type; str() so pathlib.Path inputs work as well.
        if not str(self.model_path).lower().endswith(".safetensors"):
            print("⚠️ Warning: Expected .safetensors file for secure loading.")
        # Load using our helper function
        self.model = load_coherence_model(self.model_path, str(self.device))
        # This class only runs inference: make sure dropout / batch-norm layers
        # use inference behaviour. Idempotent if the loader already did this.
        # NOTE(review): assumes load_coherence_model returns a torch.nn.Module.
        self.model.eval()

    def _setup_transforms(self):
        """Create image preprocessing pipeline."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet standard size
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std),
        ])

    def preprocess_image(self, image_path):
        """Load an image file and return a ``[1, 3, 224, 224]`` tensor on ``self.device``.

        Raises:
            Whatever PIL / torchvision raise; the error is printed first and
            then re-raised unchanged.
        """
        try:
            # Use a context manager so the underlying file handle is closed
            # deterministically; convert() forces a full load into a new image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            tensor = self.transform(image)
            # Add batch dimension [1, 3, 224, 224]
            return tensor.unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def predict(self, image_path, return_probs=False):
        """
        Make a prediction on an image.

        Args:
            image_path: Path to the image file
            return_probs: If True, return all probabilities; if False, return
                only the top category

        Returns:
            Dictionary mapping each category to its probability, or the string
            name of the most probable category.

        Raises:
            ValueError: if the model is not loaded, or its number of outputs
                does not match ``CATEGORIES``.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        input_tensor = self.preprocess_image(image_path)

        with torch.no_grad():
            output = self.model(input_tensor)

        # Softmax over the class dimension, then take the single batch item.
        probs = torch.nn.functional.softmax(output, dim=1)[0].cpu().numpy()

        # Guard against a weights/label mismatch instead of silently dropping
        # classes (or failing with an opaque IndexError).
        if probs.shape[0] != len(self.CATEGORIES):
            raise ValueError(
                f"Model produced {probs.shape[0]} class scores, expected "
                f"{len(self.CATEGORIES)} ({self.CATEGORIES})."
            )

        if not return_probs:
            # Return the category with highest probability
            return self.CATEGORIES[int(np.argmax(probs))]
        return {category: float(p) for category, p in zip(self.CATEGORIES, probs)}

    def predict_batch(self, image_paths):
        """Return one probability dictionary per path, in input order.

        Images are processed one at a time through ``predict``.
        """
        return [self.predict(img_path, return_probs=True) for img_path in image_paths]
# ============================================================================
# 3. Example Usage
# ============================================================================

# Interpretation note printed after analysing each known demo image.
_EXAMPLE_NOTES = {
    "example_coherent.jpg": "Note: This example should ideally show high 'coherent' probability.",
    "example_semi_incoherent.jpg": "Note: This example demonstrates borderline/partial coherence issues.",
    "example_incoherent.jpg": "Note: This example should show obvious generation failures.",
}


def _confidence_label(confidence):
    """Format a probability with a HIGH (>0.8) / MED (>0.6) / LOW text tag."""
    if confidence > 0.8:
        tag = "HIGH"
    elif confidence > 0.6:
        tag = "MED"
    else:
        tag = "LOW"
    return f"{confidence:.4f} ({tag})"


def _run_single_image_analysis(classifier, images):
    """Print the full probability distribution and top category for each image."""
    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)

    for img_file in images:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)

        try:
            # One forward pass yields both the distribution and the top class;
            # no need to call predict() a second time.
            probs = classifier.predict(img_file, return_probs=True)
        except Exception as e:
            print(f"Error analyzing {img_file}: {e}")
            continue

        print("Prediction probabilities:")
        for category in classifier.CATEGORIES:
            prob = probs[category]
            # Visual indicator for high confidence (>70%)
            indicator = " ★" if prob > 0.7 else ""
            print(f" {category:20} {prob:.4f}{indicator}")

        # max() keeps the first maximum in CATEGORIES order, like np.argmax.
        top_category = max(probs, key=probs.get)
        print(f"\nTop prediction: '{top_category}' ({probs[top_category]:.4f})")

        note = _EXAMPLE_NOTES.get(img_file)
        if note:
            print(note)


def _run_batch_comparison(classifier, images):
    """Run predict_batch on all images and print a comparison table."""
    print("\n" + "=" * 60)
    print("2. BATCH PREDICTION COMPARISON")
    print("-" * 40)

    try:
        print(f"Running batch prediction on {len(images)} images...")
        batch_results = classifier.predict_batch(images)
    except Exception as e:
        print(f"Error in batch prediction: {e}")
        return

    print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
    print("-" * 65)
    for img_path, result in zip(images, batch_results):
        top_cat = max(result, key=result.get)
        # Shorten filename if too long to fit the 30-char column
        display_name = os.path.basename(img_path)
        if len(display_name) > 28:
            display_name = display_name[:25] + "..."
        # Text tag (HIGH/MED/LOW) for the confidence level
        confidence_str = _confidence_label(result[top_cat])
        print(f"{display_name:30} {top_cat:20} {confidence_str:12}")

    print("\nBatch processing complete!")


def _print_summary(classifier, demo_images, available_images):
    """Print the demo summary, any missing images, and next-step hints."""
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)
    print(f"✓ Model loaded successfully on: {classifier.device}")
    print(f"✓ Analyzed {len(available_images)} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")

    missing_images = [img for img in demo_images if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")

    print("\n" + "=" * 60)
    print("Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print(" ```python")
    print(" from example_usage import CoherenceClassifier")
    # The classifier expects a .safetensors file (load_model warns otherwise).
    print(" classifier = CoherenceClassifier('your_model.safetensors')")
    print(" result = classifier.predict('your_image.jpg')")
    print(" ```")


def main():
    """Example demonstrating how to use the classifier with all three test images."""
    # Update this path to your actual model file
    MODEL_PATH = "coherence_model.safetensors"  # Change to your model filename

    # All three example images you've provided
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg",
    ]

    # Check if model file exists
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return

    # Check which demo images exist
    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return

    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)

    # Initialize classifier
    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")

    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)

    # PART 1: Detailed single image analysis for each available example
    _run_single_image_analysis(classifier, available_images)

    # PART 2: Batch prediction comparison (only meaningful with 2+ images)
    if len(available_images) > 1:
        _run_batch_comparison(classifier, available_images)

    # PART 3: Quick summary
    _print_summary(classifier, DEMO_IMAGES, available_images)


if __name__ == "__main__":
    main()