# openCLIP-waste-wizard / clip_waste_classifier / openclip_classifier.py
# ysfad — commit 7a62a28:
# "CLEAN: Back to basics - Pure OpenCLIP + CSV, no hardcoded logic, simple item matching"
"""OpenCLIP Waste Classifier using ViT-B-16 model."""
import os
import torch
import open_clip
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
class OpenCLIPWasteClassifier:
    """Waste classifier using a pre-saved OpenCLIP ViT-B-16 (LAION) model.

    Builds one text embedding per row of ``database.csv`` (prompt:
    ``"a photo of <item>"``), then classifies an input image by cosine
    similarity between its image embedding and those item embeddings.
    """

    def __init__(self):
        """Load the model, the CSV database, and precompute item embeddings.

        Raises:
            FileNotFoundError: if the saved model file or database.csv is missing.
            Exception: re-raised from any underlying load failure.
        """
        self.model_name = "ViT-B-16"
        self.pretrained = "laion2b_s34b_b88k"
        # Force CPU for HF Spaces compatibility (no GPU on the free tier).
        self.device = "cpu"
        print(f"🚀 Loading ViT-B-16 OpenCLIP model on {self.device}...")
        try:
            self._load_model()
            self._load_database()
            self._create_item_embeddings()
            print("✅ Classifier ready!")
        except Exception as e:
            print(f"❌ Error initializing classifier: {e}")
            raise

    def _load_model(self):
        """Load weights from ``models/<name>_<pretrained>_model.pth``.

        Sets ``self.model``, ``self.preprocess`` and ``self.tokenizer``.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        model_path = Path("models") / f"{self.model_name}_{self.pretrained.replace('_', '-')}_model.pth"
        print(f"📁 Looking for model at: {model_path}")
        if not model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")
        print(f"📂 Model file size: {model_path.stat().st_size / (1024*1024):.1f} MB")
        try:
            print("🔄 Loading model weights...")
            # NOTE(review): torch.load on a pickled checkpoint can execute
            # arbitrary code if the file is untrusted; consider
            # weights_only=True once the minimum torch version supports it.
            saved_data = torch.load(model_path, map_location='cpu')
            print("🏗️ Creating model architecture...")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                saved_data['model_name'],
                pretrained=None  # weights come from the local checkpoint, not a download
            )
            print("⚡ Loading model state...")
            self.model.load_state_dict(saved_data['model_state_dict'])
            self.model = self.model.to(self.device).eval()
            self.tokenizer = open_clip.get_tokenizer(self.model_name)
            print("🔤 Tokenizer ready")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _load_database(self):
        """Load the waste-item table from ``database.csv`` into ``self.df``.

        Raises:
            FileNotFoundError: if database.csv is not in the working directory.
        """
        print("📊 Loading waste database...")
        if not os.path.exists("database.csv"):
            raise FileNotFoundError("Database not found at database.csv")
        # Expected columns (used elsewhere): Item, Category, Instruction_1..3.
        self.df = pd.read_csv("database.csv")
        print(f"📊 Loaded {len(self.df)} items from database")

    def _create_item_embeddings(self):
        """Precompute L2-normalized text embeddings for every database item."""
        print("🔗 Creating item embeddings...")
        # One simple CLIP-style prompt per item; row order matches self.df.
        item_texts = [f"a photo of {item}" for item in self.df['Item']]
        item_tokens = self.tokenizer(item_texts).to(self.device)
        with torch.no_grad():
            self.item_embeddings = self.model.encode_text(item_tokens)
            # Normalize so (image @ text.T) is a cosine similarity.
            self.item_embeddings = self.item_embeddings / self.item_embeddings.norm(dim=-1, keepdim=True)
        print(f"✅ Created embeddings for {len(self.item_embeddings)} items")

    def classify_image(self, image_path_or_pil, top_k=5):
        """Classify a waste item from an image.

        Args:
            image_path_or_pil: path to an image file, or a PIL.Image instance.
            top_k: number of best matches to include in ``top_items``.

        Returns:
            dict with keys ``predicted_item``, ``predicted_category``,
            ``best_confidence`` and ``top_items`` on success, or
            ``{"error": ...}`` on any failure (this method never raises).
        """
        try:
            # Accept either a filesystem path or an in-memory PIL image.
            if isinstance(image_path_or_pil, str):
                if not os.path.exists(image_path_or_pil):
                    return {"error": f"Image file not found: {image_path_or_pil}"}
                image = Image.open(image_path_or_pil).convert('RGB')
            else:
                image = image_path_or_pil.convert('RGB')
            # Embed the image and L2-normalize, mirroring the text embeddings.
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_embedding = self.model.encode_image(image_tensor)
                image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            # Cosine similarity against every precomputed item embedding.
            similarities = (image_embedding @ self.item_embeddings.T).cpu().numpy()[0]
            # Indices of the top_k highest similarities, best first.
            top_indices = np.argsort(similarities)[::-1][:top_k]
            results = []
            for idx in top_indices:
                row = self.df.iloc[idx]
                similarity_score = float(similarities[idx])
                # Join the non-empty instruction columns straight from the CSV.
                # FIX: cast cells to str before strip() — pandas parses purely
                # numeric cells as floats, and float has no .strip(), which
                # previously crashed into the generic error path below.
                disposal_parts = []
                for col in ['Instruction_1', 'Instruction_2', 'Instruction_3']:
                    if pd.notna(row[col]) and str(row[col]).strip():
                        disposal_parts.append(str(row[col]).strip())
                disposal_method = ' '.join(disposal_parts) if disposal_parts else "No instructions available"
                results.append({
                    'item': row['Item'],
                    'category': row['Category'],
                    'disposal_method': disposal_method,
                    'confidence': similarity_score
                })
            # Summarize around the single best match (None if top_k <= 0).
            best_match = results[0] if results else None
            return {
                'predicted_item': best_match['item'] if best_match else "Unknown",
                'predicted_category': best_match['category'] if best_match else "Unknown",
                'best_confidence': best_match['confidence'] if best_match else 0.0,
                'top_items': results
            }
        except Exception as e:
            # Broad catch is deliberate: callers always get a dict, never a raise.
            return {"error": f"Classification error: {str(e)}"}