# NOTE(review): removed Hugging Face Spaces status-page residue ("Spaces:" /
# "Runtime error") that was captured along with the source and is not code.
"""OpenCLIP Waste Classifier using ViT-B-16 model."""
import os
from pathlib import Path

import numpy as np
import open_clip
import pandas as pd
import torch
from PIL import Image
class OpenCLIPWasteClassifier:
    """Waste classifier using the ViT-B-16 LAION OpenCLIP model.

    Loads a pre-saved checkpoint from ``models/``, embeds every item in
    ``database.csv`` as text, and classifies an image by cosine similarity
    between the image embedding and the item text embeddings.
    """

    def __init__(self):
        """Initialize the classifier with the pre-saved ViT-B-16 model.

        Raises:
            FileNotFoundError: If the model checkpoint or ``database.csv``
                is missing.
        """
        self.model_name = "ViT-B-16"
        self.pretrained = "laion2b_s34b_b88k"
        # Force CPU for HF Spaces compatibility (no GPU on the free tier).
        self.device = "cpu"
        print(f"🚀 Loading ViT-B-16 OpenCLIP model on {self.device}...")
        try:
            self._load_model()
            self._load_database()
            self._create_item_embeddings()
            print("✅ Classifier ready!")
        except Exception as e:
            # Surface the failure in the Space logs, then re-raise so the
            # app does not start with a half-initialized classifier.
            print(f"❌ Error initializing classifier: {e}")
            raise

    def _load_model(self):
        """Load the pre-saved model weights and build the architecture."""
        model_path = Path("models") / f"{self.model_name}_{self.pretrained.replace('_', '-')}_model.pth"
        print(f"📁 Looking for model at: {model_path}")
        if not model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")
        print(f"📂 Model file size: {model_path.stat().st_size / (1024*1024):.1f} MB")
        try:
            print("🔄 Loading model weights...")
            saved_data = torch.load(model_path, map_location='cpu')
            print("🏗️ Creating model architecture...")
            # pretrained=None: the weights come from the local checkpoint,
            # not from a remote download.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                saved_data['model_name'],
                pretrained=None
            )
            print("⚡ Loading model state...")
            self.model.load_state_dict(saved_data['model_state_dict'])
            self.model = self.model.to(self.device).eval()
            self.tokenizer = open_clip.get_tokenizer(self.model_name)
            print("🔤 Tokenizer ready")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _load_database(self):
        """Load the waste-item database from ``database.csv``."""
        print("📊 Loading waste database...")
        if not os.path.exists("database.csv"):
            raise FileNotFoundError("Database not found at database.csv")
        self.df = pd.read_csv("database.csv")
        print(f"📊 Loaded {len(self.df)} items from database")

    def _create_item_embeddings(self):
        """Create L2-normalized text embeddings for all database items."""
        print("🔗 Creating item embeddings...")
        # Simple text descriptions for each item.
        item_texts = [f"a photo of {item}" for item in self.df['Item']]
        item_tokens = self.tokenizer(item_texts).to(self.device)
        with torch.no_grad():
            self.item_embeddings = self.model.encode_text(item_tokens)
            # Normalize so a plain dot product equals cosine similarity.
            self.item_embeddings = self.item_embeddings / self.item_embeddings.norm(dim=-1, keepdim=True)
        print(f"✅ Created embeddings for {len(self.item_embeddings)} items")

    @staticmethod
    def _disposal_method(row):
        """Join the row's non-empty Instruction_* cells into one string."""
        parts = []
        for col in ('Instruction_1', 'Instruction_2', 'Instruction_3'):
            value = row[col]
            # CSV cells may be parsed as non-strings (e.g. numbers);
            # coerce with str() before stripping to avoid AttributeError.
            if pd.notna(value):
                text = str(value).strip()
                if text:
                    parts.append(text)
        return ' '.join(parts) if parts else "No instructions available"

    def classify_image(self, image_path_or_pil, top_k=5):
        """Classify a waste item from an image.

        Args:
            image_path_or_pil: Path to an image file, or a PIL-style image
                object exposing ``convert('RGB')``.
            top_k: Number of top matches to include in the result.

        Returns:
            dict with keys ``predicted_item``, ``predicted_category``,
            ``best_confidence`` and ``top_items``, or ``{"error": ...}``
            on failure (this method never raises).
        """
        try:
            # Accept either a filesystem path or an in-memory image.
            if isinstance(image_path_or_pil, str):
                if not os.path.exists(image_path_or_pil):
                    return {"error": f"Image file not found: {image_path_or_pil}"}
                image = Image.open(image_path_or_pil).convert('RGB')
            else:
                image = image_path_or_pil.convert('RGB')
            # Embed the image and normalize for cosine similarity.
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_embedding = self.model.encode_image(image_tensor)
                image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            # Similarity against all (already normalized) item embeddings.
            similarities = (image_embedding @ self.item_embeddings.T).cpu().numpy()[0]
            # Indices of the top_k most similar items, best first.
            top_indices = np.argsort(similarities)[::-1][:top_k]
            results = []
            for idx in top_indices:
                row = self.df.iloc[idx]
                results.append({
                    'item': row['Item'],
                    'category': row['Category'],
                    'disposal_method': self._disposal_method(row),
                    'confidence': float(similarities[idx])
                })
            best_match = results[0] if results else None
            return {
                'predicted_item': best_match['item'] if best_match else "Unknown",
                'predicted_category': best_match['category'] if best_match else "Unknown",
                'best_confidence': best_match['confidence'] if best_match else 0.0,
                'top_items': results
            }
        except Exception as e:
            # Gradio/UI callers expect an error dict, never an exception.
            return {"error": f"Classification error: {str(e)}"}