Spaces:
Sleeping
Sleeping
| from transformers import pipeline | |
| import torch | |
| import os | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from config import CANDIDATE_LABELS, IMAGE_SIZE | |
class GenAILabeler:
    """Label image clusters with a mineral name via zero-shot classification.

    NOTE: the classifier is ``facebook/bart-large-mnli``, a text-only
    natural-language-inference model. It never sees image pixels: the
    "visual" description it classifies is built purely from keywords found
    in the sample image's file name (see ``analyze_image_content``).
    """

    # (filename keywords, description) rules checked in order by
    # analyze_image_content. A rule fires if ANY of its keywords is a
    # substring of the lower-cased file name; matching descriptions are
    # joined in this order. Subclasses may override to add rock types.
    FEATURE_RULES = (
        (("gold",), "visible metallic particles, yellow coloration"),
        (("iron", "pyrite"), "dark metallic appearance, magnetic properties"),
        (("lithium", "spodumene"), "light-colored minerals, pegmatite texture"),
        (("copper",), "green or blue coloration, metallic luster"),
        (("quartz",), "clear or white crystalline structure"),
        (("granite",), "mixed mineral composition, coarse-grained"),
        (("basalt",), "dark fine-grained texture"),
    )

    # Generic description used when no rule matches the file name.
    DEFAULT_FEATURES = (
        "visible mineral grains",
        "distinctive color patterns",
        "unique textural features",
    )

    def __init__(self):
        """Load the zero-shot text classifier (first CUDA GPU if available)."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The HF pipeline takes a device index (0 = first GPU, -1 = CPU);
        # derive it from self.device so the two can never disagree.
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if self.device.type == "cuda" else -1,
        )
        # Labels the classifier chooses between, taken from config.
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Describe an image's likely visual features from its file name.

        No pixels are read; this is a keyword heuristic over the base name.

        Args:
            image_path: Path to the image; only its base name is inspected.

        Returns:
            A comma-separated description built from every matching entry in
            ``FEATURE_RULES``, or from ``DEFAULT_FEATURES`` if none match.
        """
        filename = os.path.basename(image_path).lower()
        characteristics = [
            description
            for keywords, description in self.FEATURE_RULES
            if any(keyword in filename for keyword in keywords)
        ]
        return ", ".join(characteristics or self.DEFAULT_FEATURES)

    def label_cluster(self, sample_image_path):
        """Classify one representative image into a candidate mineral label.

        Args:
            sample_image_path: Path of the image chosen to represent a cluster.

        Returns:
            dict with keys ``label`` (top label), ``confidence`` (its score),
            ``all_scores`` (label -> score for every candidate),
            ``prompt_used`` (text sent to the classifier) and
            ``visual_features`` (the filename-derived description).
        """
        visual_features = self.analyze_image_content(sample_image_path)
        prompt = (
            f"A geological drill core sample showing {visual_features}. "
            "What economically important mineral is most likely present in this rock sample?"
        )
        result = self.classifier(prompt, self.candidate_labels)
        # The zero-shot pipeline returns labels sorted by descending score,
        # so index 0 is the top prediction.
        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt,
            "visual_features": visual_features,
        }

    def label_all_clusters(self, cluster_map):
        """Label every cluster from its first image and print a summary.

        Args:
            cluster_map: Mapping of cluster id -> sequence of image paths.
                Clusters with no images are skipped (reported on stdout).

        Returns:
            dict mapping cluster id -> the dict returned by ``label_cluster``.
        """
        cluster_labels = {}
        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            # len() rather than truthiness so array-like sequences work too.
            if len(image_paths) == 0:
                print(f"\nCluster {cluster_id}: no images, skipping")
                continue
            # The first image stands in for the whole cluster.
            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f" Primary Label: {label_info['label']}")
            print(f" Confidence: {label_info['confidence']:.3f}")
            print(f" Key Features: {label_info['visual_features']}")

            # Up to three runners-up after the top label.
            ranked = sorted(
                label_info['all_scores'].items(),
                key=lambda item: item[1],
                reverse=True,
            )
            alternatives = ranked[1:4]
            if alternatives:
                print(" Alternative possibilities:")
                for label, score in alternatives:
                    print(f" - {label}: {score:.3f}")
        return cluster_labels
if __name__ == "__main__":
    # Library module: GenAILabeler is driven by the main pipeline, so running
    # this file directly intentionally does nothing.
    ...