| """Finetuned CLIP Waste Classifier using ViT-B-16 model.""" | |
| import os | |
| import torch | |
| import open_clip | |
| import numpy as np | |
| import pandas as pd | |
| from pathlib import Path | |
| from PIL import Image | |
| import json | |
| import urllib.request | |
| import urllib.error | |
| import tempfile | |
class FinetunedCLIPWasteClassifier:
    """Waste classifier using finetuned ViT-B-16 model with fluid OpenCLIP matching."""

    def __init__(self, model_path=None, hf_model_id=None):
        """Initialize classifier with finetuned model."""
        self.device = "cpu"  # Force CPU for consistency

        # Use writable cache directories for containers
        self.cache_dir = os.environ.get('HF_CACHE_DIR', '/tmp/hf_cache')
        os.makedirs(self.cache_dir, exist_ok=True)

        # Model source priority: local file -> HF Hub -> fallback to pretrained
        self.model_path = model_path or "models_finetuned/best_clip_finetuned_vit-b-16.pth"
        self.hf_model_id = hf_model_id  # e.g., "username/waste-clip-finetuned"

        print("🚀 Loading CLIP waste classifier...")

        try:
            if self._try_load_finetuned_model():
                self._load_database()
                self._create_item_embeddings()  # Use database items, not fixed classes
                print("✅ Finetuned classifier ready!")
            else:
                print("🔄 Falling back to pretrained model...")
                self._load_pretrained_fallback()
        except Exception as e:
            print(f"❌ Error initializing classifier: {e}")
            print("🔄 Falling back to pretrained model...")
            self._load_pretrained_fallback()

    def _try_load_finetuned_model(self):
        """Try to load finetuned model from various sources."""
        # Try local file first
        if os.path.exists(self.model_path):
            print(f"📁 Found local model at {self.model_path}")
            self._load_finetuned_model_file(self.model_path)
            return True

        # Try downloading from Hugging Face Hub
        if self.hf_model_id:
            print(f"🤗 Trying to download from Hugging Face: {self.hf_model_id}")
            if self._download_from_hf_hub():
                self._load_finetuned_model_file(self.model_path)
                return True

        return False

    def _download_from_hf_hub(self):
        """Download model from Hugging Face Hub."""
        try:
            from huggingface_hub import hf_hub_download

            model_file = hf_hub_download(
                repo_id=self.hf_model_id,
                filename="best_clip_finetuned_vit-b-16.pth",
                cache_dir=self.cache_dir
            )

            # Copy to expected location
            os.makedirs("/tmp/models_finetuned", exist_ok=True)
            temp_model_path = "/tmp/models_finetuned/best_clip_finetuned_vit-b-16.pth"
            shutil.copy(model_file, temp_model_path)
            self.model_path = temp_model_path

            print("✅ Downloaded model from Hugging Face Hub")
            return True
        except ImportError:
            print("❌ huggingface_hub not installed")
            return False
        except Exception as e:
            print(f"❌ Failed to download from HF Hub: {e}")
            return False

    def _load_finetuned_model_file(self, model_path):
        """Load the finetuned model from file."""
        print(f"📂 Model file size: {Path(model_path).stat().st_size / (1024*1024*1024):.1f} GB")

        # Load saved checkpoint (expects 'model_name', 'pretrained',
        # 'model_state_dict' and optionally 'val_accuracy' keys)
        print("🔄 Loading model checkpoint...")
        checkpoint = torch.load(model_path, map_location='cpu')
        self.model_name = checkpoint['model_name']
        self.pretrained = checkpoint['pretrained']

        print("🏗️ Creating model architecture...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name, pretrained=None, cache_dir=self.cache_dir
        )

        # Load finetuned weights
        print("⚡ Loading finetuned weights...")
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device).eval()

        # Get tokenizer
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        self.is_finetuned = True

        print(f"🎯 Model validation accuracy: {checkpoint.get('val_accuracy', 'Unknown')}")
        print("🔄 Using fluid OpenCLIP matching against database items")

    def _load_pretrained_fallback(self):
        """Fall back to the pretrained model if the finetuned model fails."""
        print("🔄 Loading pretrained ViT-B-16 model...")
        self.model_name = "ViT-B-16"
        self.pretrained = "laion2b_s34b_b88k"

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name, pretrained=self.pretrained, cache_dir=self.cache_dir
        )
        self.model = self.model.to(self.device).eval()
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        self.is_finetuned = False  # Mark that the fallback model is in use

        self._load_database()
        self._create_item_embeddings()

    def _load_database(self):
        """Load waste database."""
        print("📊 Loading waste database...")

        if not os.path.exists("database.csv"):
            raise FileNotFoundError("Database not found at database.csv")

        self.df = pd.read_csv("database.csv")
        print(f"📊 Loaded {len(self.df)} items from database")

    def _create_item_embeddings(self):
        """Create embeddings for all items in database (fluid matching)."""
        print("🔗 Creating item embeddings for fluid matching...")

        # Simple text descriptions for each item (same as original OpenCLIP approach)
        item_texts = [f"a photo of {item}" for item in self.df['Item']]

        # Create embeddings using finetuned encoder
        item_tokens = self.tokenizer(item_texts).to(self.device)
        with torch.no_grad():
            self.item_embeddings = self.model.encode_text(item_tokens)
            self.item_embeddings = self.item_embeddings / self.item_embeddings.norm(dim=-1, keepdim=True)

        print(f"✅ Created embeddings for {len(self.item_embeddings)} database items")

    def classify_image(self, image_path_or_pil, top_k=5):
        """Classify waste item from image using fluid OpenCLIP matching."""
        try:
            # Handle image input (file path or PIL image)
            if isinstance(image_path_or_pil, str):
                if not os.path.exists(image_path_or_pil):
                    return {"error": f"Image file not found: {image_path_or_pil}"}
                image = Image.open(image_path_or_pil).convert('RGB')
            else:
                image = image_path_or_pil.convert('RGB')

            # Get image embedding using finetuned encoder
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_embedding = self.model.encode_image(image_tensor)
                image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

            # Compare with all database items (fluid matching)
            similarities = (image_embedding @ self.item_embeddings.T).cpu().numpy()[0]

            # Get top matches
            top_indices = np.argsort(similarities)[::-1][:top_k]

            results = []
            for idx in top_indices:
                row = self.df.iloc[idx]
                similarity_score = float(similarities[idx])

                # Get disposal instructions directly from CSV
                disposal_parts = []
                for col in ['Instruction_1', 'Instruction_2', 'Instruction_3']:
                    if pd.notna(row[col]) and str(row[col]).strip():
                        disposal_parts.append(str(row[col]).strip())
                disposal_method = ' '.join(disposal_parts) if disposal_parts else "No instructions available"

                results.append({
                    'item': row['Item'],
                    'category': row['Category'],
                    'disposal_method': disposal_method,
                    'confidence': similarity_score
                })

            # Return results
            best_match = results[0] if results else None

            # Report whether the finetuned model or the pretrained fallback is in use
            model_type = 'finetuned' if getattr(self, 'is_finetuned', False) else 'pretrained'

            return {
                'predicted_item': best_match['item'] if best_match else "Unknown",
                'predicted_category': best_match['category'] if best_match else "Unknown",
                'best_confidence': best_match['confidence'] if best_match else 0.0,
                'top_items': results,
                'model_type': model_type
            }
        except Exception as e:
            return {"error": f"Classification error: {str(e)}"}

    def get_model_info(self):
        """Get information about the loaded model."""
        model_type = 'finetuned' if getattr(self, 'is_finetuned', False) else 'pretrained'
        return {
            'model_name': self.model_name,
            'pretrained': getattr(self, 'pretrained', 'Unknown'),
            'num_classes': len(self.df) if hasattr(self, 'df') else 0,
            'classes': list(self.df['Item']) if hasattr(self, 'df') else [],
            'model_path': getattr(self, 'model_path', 'Unknown'),
            'device': self.device,
            'model_type': model_type,
            'matching_type': 'fluid_database'
        }
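

# Minimal usage sketch. Assumptions not in the original file: the script is run
# from the directory that holds database.csv, and "test_image.jpg" is a
# placeholder path for an image of a waste item.
if __name__ == "__main__":
    classifier = FinetunedCLIPWasteClassifier()
    print(classifier.get_model_info())

    result = classifier.classify_image("test_image.jpg", top_k=3)
    if "error" in result:
        print(result["error"])
    else:
        print(f"Predicted: {result['predicted_item']} ({result['predicted_category']}) "
              f"with confidence {result['best_confidence']:.3f}")
        for match in result['top_items']:
            print(f"  {match['item']}: {match['confidence']:.3f} -> {match['disposal_method']}")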