|
|
""" |
|
|
Minimal example showing how to load and use the Coherence Detection model.

Requires the model weights in safetensors format and the ``model_architecture.py``
module that defines the exact network architecture.
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The exact network definition lives in a separate project module; without it
# the weights cannot be loaded, so the script stops here.
try:
    from model_architecture import load_coherence_model
except ImportError as e:
    # Fatal diagnostics go to stderr so they are not mixed with normal output;
    # the non-zero exit status signals the failure to shells/CI.
    print(f"Error: model_architecture.py not found or has issues: {e}", file=sys.stderr)
    print("Please download it from the repository.", file=sys.stderr)
    sys.exit(1)
else:
    print("✓ Imported exact model architecture")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CoherenceClassifier:
    """Classify images as coherent, incoherent or semi-incoherent.

    Wraps the model returned by ``load_coherence_model`` together with the
    preprocessing applied before inference: resize to 224x224, convert to a
    tensor, and normalize with the ImageNet channel mean/std constants.
    """

    # Output labels, indexed by the model's output position. Presumably this
    # matches the label order used at training time -- confirm against the
    # training code.
    CATEGORIES = ["coherent", "incoherent", "semi-incoherent"]

    def __init__(self, model_path, device="auto"):
        """Build the preprocessing pipeline and load the model immediately.

        Args:
            model_path: Path to the ``.safetensors`` weights file (``str`` or
                any ``os.PathLike``).
            device: ``"auto"`` (CUDA when available, otherwise CPU), ``"cuda"``,
                ``"cpu"``, or any other string accepted by ``torch.device``.
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.model = None
        self.transform = None

        # Per-channel RGB normalization constants (ImageNet statistics).
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]

        self._setup_transforms()
        self.load_model()

    def _get_device(self, device_str):
        """Resolve ``device_str`` to a ``torch.device`` ("auto" prefers CUDA)."""
        if device_str == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device_str)

    def load_model(self):
        """Load the weights at ``self.model_path`` onto ``self.device``.

        Prints a warning, but still proceeds, when the file does not have a
        ``.safetensors`` extension.
        """
        # fspath lets callers pass pathlib.Path as well as str.
        model_path = os.fspath(self.model_path)
        print("Loading coherence model...")
        print(f" File: {os.path.basename(model_path)}")
        print(f" Device: {self.device}")

        if not model_path.lower().endswith('.safetensors'):
            print("⚠️ Warning: Expected .safetensors file for secure loading.")

        self.model = load_coherence_model(model_path, str(self.device))
        # Inference must not run with training-mode layers (dropout, batch-norm
        # statistic updates). eval() is a no-op if the loader already set it.
        if hasattr(self.model, "eval"):
            self.model.eval()

    def _setup_transforms(self):
        """Create the image preprocessing pipeline."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std),
        ])

    def preprocess_image(self, image_path):
        """Load an image file and turn it into a single-image model batch.

        Args:
            image_path: Path of the image file to load.

        Returns:
            A tensor of shape ``(1, 3, 224, 224)`` on ``self.device``.

        Raises:
            Any exception raised while opening/decoding the image or applying
            the transform; it is printed before being re-raised.
        """
        try:
            # The context manager closes the underlying file handle. convert()
            # returns a fully loaded copy, so the pixels outlive the `with`.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            tensor = self.transform(image)
            # Add the batch dimension and move to the model's device.
            return tensor.unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            raise

    def _logits_to_probs(self, logits):
        """Softmax the first row of ``logits`` into a 1-D NumPy probability vector.

        Assumes ``logits`` is shaped ``(batch, num_outputs)``; only the first
        row is used.

        Raises:
            ValueError: If the number of outputs differs from
                ``len(CATEGORIES)``. Without this check, labels would be
                silently mismatched or indexing would fail.
        """
        probs = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()
        if probs.shape[0] != len(self.CATEGORIES):
            raise ValueError(
                f"Model produced {probs.shape[0]} outputs but "
                f"{len(self.CATEGORIES)} categories are defined: {self.CATEGORIES}"
            )
        return probs

    def predict(self, image_path, return_probs=False):
        """Classify one image.

        Args:
            image_path: Path to the image file.
            return_probs: If True, return every category's probability. If
                False, return only the most likely category name.

        Returns:
            ``dict`` mapping category name to probability (``float``) when
            ``return_probs`` is True, otherwise the top category ``str``.

        Raises:
            ValueError: If the model is not loaded, or its output size does not
                match ``CATEGORIES``.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        input_tensor = self.preprocess_image(image_path)
        with torch.no_grad():
            logits = self.model(input_tensor)

        probs = self._logits_to_probs(logits)
        if return_probs:
            return {category: float(p) for category, p in zip(self.CATEGORIES, probs)}
        return self.CATEGORIES[int(np.argmax(probs))]

    def predict_batch(self, image_paths):
        """Return ``predict(path, return_probs=True)`` for each path, in order.

        Runs one forward pass per image.
        """
        return [self.predict(img_path, return_probs=True) for img_path in image_paths]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the demo on the bundled example images.

    Expects ``coherence_model.safetensors`` and up to three example images in
    the current working directory. Prints, in order:

    - per-image probabilities;
    - a batch comparison table, when more than one image is present;
    - a summary.

    Returns early, after printing an error, if the model file or all example
    images are missing.
    """
    MODEL_PATH = "coherence_model.safetensors"
    DEMO_IMAGES = [
        "example_coherent.jpg",
        "example_semi_incoherent.jpg",
        "example_incoherent.jpg",
    ]

    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please download the model from the Hugging Face repository:")
        print("https://huggingface.co/your-username/your-model-name")
        return

    available_images = [img for img in DEMO_IMAGES if os.path.exists(img)]
    if not available_images:
        print("Error: No example images found.")
        print(f"Please add one or more of these images to the directory: {DEMO_IMAGES}")
        return

    print("=" * 60)
    print("Coherence Detection Model - Complete Demo")
    print("=" * 60)

    print(f"Loading model from: {MODEL_PATH}")
    classifier = CoherenceClassifier(MODEL_PATH, device="auto")

    print(f"\nFound {len(available_images)} example image(s) for demonstration.")
    print("-" * 60)

    analyzed = _run_single_image_analysis(classifier, available_images)

    if len(available_images) > 1:
        _run_batch_comparison(classifier, available_images)

    _print_summary(classifier, DEMO_IMAGES, available_images, analyzed)


def _run_single_image_analysis(classifier, image_paths):
    """Print per-category probabilities for each image; return how many succeeded."""
    # Explanatory note shown after the result for each bundled example image.
    notes = {
        "example_coherent.jpg": "Note: This example should ideally show high 'coherent' probability.",
        "example_semi_incoherent.jpg": "Note: This example demonstrates borderline/partial coherence issues.",
        "example_incoherent.jpg": "Note: This example should show obvious generation failures.",
    }

    print("\n1. SINGLE IMAGE ANALYSIS")
    print("-" * 40)

    succeeded = 0
    for img_file in image_paths:
        print(f"\nAnalyzing: {img_file}")
        print("-" * 30)

        try:
            probs = classifier.predict(img_file, return_probs=True)
        except Exception as e:
            print(f"Error analyzing {img_file}: {e}")
            continue

        print("Prediction probabilities:")
        for category in classifier.CATEGORIES:
            prob = probs[category]
            # Star categories the model is clearly confident about.
            indicator = " ★" if prob > 0.7 else ""
            print(f" {category:20} {prob:.4f}{indicator}")

        # Take the top class from the probabilities already computed rather
        # than running a second forward pass. On ties the first maximum in
        # category order wins, matching argmax.
        top_category = max(probs, key=probs.get)
        print(f"\nTop prediction: '{top_category}' ({probs[top_category]:.4f})")

        note = notes.get(img_file)
        if note:
            print(note)
        succeeded += 1

    return succeeded


def _run_batch_comparison(classifier, image_paths):
    """Print a table of each image's top prediction and confidence level."""
    print("\n" + "=" * 60)
    print("2. BATCH PREDICTION COMPARISON")
    print("-" * 40)

    try:
        print(f"Running batch prediction on {len(image_paths)} images...")
        batch_results = classifier.predict_batch(image_paths)
    except Exception as e:
        print(f"Error in batch prediction: {e}")
        return

    print(f"\n{'Image':30} {'Top Prediction':20} {'Confidence':12}")
    print("-" * 65)

    for img_path, result in zip(image_paths, batch_results):
        top_cat = max(result, key=result.get)
        confidence = result[top_cat]

        # Truncate long file names so the table columns stay aligned.
        display_name = os.path.basename(img_path)
        if len(display_name) > 28:
            display_name = display_name[:25] + "..."

        if confidence > 0.8:
            level = "HIGH"
        elif confidence > 0.6:
            level = "MED"
        else:
            level = "LOW"
        confidence_str = f"{confidence:.4f} ({level})"

        print(f"{display_name:30} {top_cat:20} {confidence_str:12}")

    print("\nBatch processing complete!")


def _print_summary(classifier, demo_images, available_images, analyzed):
    """Print the closing summary, missing-image hints and usage instructions."""
    print("\n" + "=" * 60)
    print("DEMO SUMMARY")
    print("-" * 40)

    print(f"✓ Model loaded successfully on: {classifier.device}")
    print(f"✓ Analyzed {analyzed} example image(s)")
    failed = len(available_images) - analyzed
    if failed:
        print(f"✗ Failed to analyze {failed} example image(s)")
    print(f"✓ Example categories: {classifier.CATEGORIES}")

    missing_images = [img for img in demo_images if img not in available_images]
    if missing_images:
        print(f"\nNote: Missing example images: {missing_images}")
        print("To complete the demo, add these images to the directory.")

    print("\n" + "=" * 60)
    print("Demo completed with errors." if failed else "Demo completed successfully!")
    print("\nNext steps:")
    print("1. Try your own images by modifying the DEMO_IMAGES list")
    print("2. Use the classifier in your own code:")
    print(" ```python")
    print(" from example_usage import CoherenceClassifier")
    # The loader expects safetensors weights (it warns on any other extension).
    print(" classifier = CoherenceClassifier('your_model.safetensors')")
    print(" result = classifier.predict('your_image.jpg')")
    print(" ```")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script; importing this module to
    # reuse CoherenceClassifier does not start it.
    main()
|
|
|