# Source: OpenCLIP-waste-wizard / clip_waste_classifier / finetuned_classifier.py
# Uploaded by ysfad — commit 8d41103:
# "Fix container permissions and implement fluid OpenCLIP matching"
"""Finetuned CLIP Waste Classifier using ViT-B-16 model."""
import os
import torch
import open_clip
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import json
import urllib.request
import urllib.error
import tempfile
class FinetunedCLIPWasteClassifier:
    """Waste classifier using finetuned ViT-B-16 model with fluid OpenCLIP matching.

    "Fluid" matching means there is no fixed classification head: every item
    name in ``database.csv`` is embedded with the CLIP text encoder, and an
    input image is matched against those embeddings by cosine similarity.
    """

    def __init__(self, model_path=None, hf_model_id=None):
        """Initialize classifier with finetuned model.

        Args:
            model_path: Path to a finetuned checkpoint (.pth). Defaults to
                ``models_finetuned/best_clip_finetuned_vit-b-16.pth``.
            hf_model_id: Optional Hugging Face Hub repo id
                (e.g. "username/waste-clip-finetuned") used to download the
                checkpoint when no local file exists.
        """
        self.device = "cpu"  # Force CPU for consistency
        # Use writable cache directories for containers
        self.cache_dir = os.environ.get('HF_CACHE_DIR', '/tmp/hf_cache')
        os.makedirs(self.cache_dir, exist_ok=True)
        # Model source priority: local file -> HF Hub -> fallback to pretrained
        self.model_path = model_path or "models_finetuned/best_clip_finetuned_vit-b-16.pth"
        self.hf_model_id = hf_model_id  # e.g., "username/waste-clip-finetuned"
        # Track which weights were actually loaded.  The previous heuristic
        # (checking for 'finetuned' in model_path) misreported the pretrained
        # fallback as 'finetuned' because the default path contains that word.
        self.is_finetuned = False
        print(f"🚀 Loading CLIP waste classifier...")
        try:
            if self._try_load_finetuned_model():
                self.is_finetuned = True
                self._load_database()
                self._create_item_embeddings()  # Use database items, not fixed classes
                print("✅ Finetuned classifier ready!")
            else:
                print("🔄 Falling back to pretrained model...")
                self._load_pretrained_fallback()
        except Exception as e:
            print(f"❌ Error initializing classifier: {e}")
            print("🔄 Falling back to pretrained model...")
            self._load_pretrained_fallback()

    def _try_load_finetuned_model(self):
        """Try to load the finetuned model from various sources.

        Returns:
            True when a finetuned checkpoint was found and loaded,
            False otherwise.
        """
        # Try local file first
        if os.path.exists(self.model_path):
            print(f"📁 Found local model at {self.model_path}")
            self._load_finetuned_model_file(self.model_path)
            return True
        # Try downloading from Hugging Face Hub
        if self.hf_model_id:
            print(f"🤗 Trying to download from Hugging Face: {self.hf_model_id}")
            if self._download_from_hf_hub():
                self._load_finetuned_model_file(self.model_path)
                return True
        return False

    def _download_from_hf_hub(self):
        """Download the checkpoint from the Hugging Face Hub.

        On success ``self.model_path`` is repointed at the local copy under
        /tmp (writable in containers).

        Returns:
            True on success, False when the download or the
            ``huggingface_hub`` import fails.
        """
        try:
            from huggingface_hub import hf_hub_download
            model_file = hf_hub_download(
                repo_id=self.hf_model_id,
                filename="best_clip_finetuned_vit-b-16.pth",
                cache_dir=self.cache_dir
            )
            # Copy to a predictable, writable location
            os.makedirs("/tmp/models_finetuned", exist_ok=True)
            import shutil
            temp_model_path = "/tmp/models_finetuned/best_clip_finetuned_vit-b-16.pth"
            shutil.copy(model_file, temp_model_path)
            self.model_path = temp_model_path
            print(f"✅ Downloaded model from Hugging Face Hub")
            return True
        except ImportError:
            print("❌ huggingface_hub not installed")
            return False
        except Exception as e:
            print(f"❌ Failed to download from HF Hub: {e}")
            return False

    def _load_finetuned_model_file(self, model_path):
        """Load the finetuned model from a checkpoint file.

        The checkpoint is expected to be a dict with 'model_name',
        'pretrained' and 'model_state_dict' keys (optionally 'val_accuracy').
        """
        size_gb = Path(model_path).stat().st_size / (1024 * 1024 * 1024)
        print(f"📂 Model file size: {size_gb:.1f} GB")
        # Load saved model data
        print("🔄 Loading model checkpoint...")
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location='cpu')
        self.model_name = checkpoint['model_name']
        self.pretrained = checkpoint['pretrained']
        print(f"🏗️ Creating model architecture...")
        # pretrained=None: build the architecture only; weights come from the
        # checkpoint below.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name, pretrained=None, cache_dir=self.cache_dir
        )
        # Load finetuned weights
        print("⚡ Loading finetuned weights...")
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device).eval()
        # Get tokenizer
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        print(f"🎯 Model validation accuracy: {checkpoint.get('val_accuracy', 'Unknown')}")
        print("🔄 Using fluid OpenCLIP matching against database items")

    def _load_pretrained_fallback(self):
        """Fallback to the pretrained model if the finetuned model fails."""
        print("🔄 Loading pretrained ViT-B-16 model...")
        self.is_finetuned = False  # record that finetuned weights are NOT in use
        self.model_name = "ViT-B-16"
        self.pretrained = "laion2b_s34b_b88k"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name, pretrained=self.pretrained, cache_dir=self.cache_dir
        )
        self.model = self.model.to(self.device).eval()
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        self._load_database()
        self._create_item_embeddings()

    def _load_database(self):
        """Load the waste database from database.csv into ``self.df``.

        Raises:
            FileNotFoundError: when database.csv is missing.
        """
        print("📊 Loading waste database...")
        if not os.path.exists("database.csv"):
            raise FileNotFoundError("Database not found at database.csv")
        self.df = pd.read_csv("database.csv")
        print(f"📊 Loaded {len(self.df)} items from database")

    def _create_item_embeddings(self):
        """Create L2-normalized text embeddings for every database item."""
        print("🔗 Creating item embeddings for fluid matching...")
        # Simple text descriptions for each item (same as original OpenCLIP approach)
        item_texts = [f"a photo of {item}" for item in self.df['Item']]
        # Create embeddings using the (possibly finetuned) text encoder
        item_tokens = self.tokenizer(item_texts).to(self.device)
        with torch.no_grad():
            self.item_embeddings = self.model.encode_text(item_tokens)
            # Normalize so the dot products in classify_image are cosine similarities
            self.item_embeddings = self.item_embeddings / self.item_embeddings.norm(dim=-1, keepdim=True)
        print(f"✅ Created embeddings for {len(self.item_embeddings)} database items")

    def _model_type(self):
        """Return 'finetuned' or 'pretrained' based on the weights actually in use.

        Replaces the old path-substring heuristic, which wrongly reported
        'finetuned' after the pretrained fallback (the default model_path
        contains the word 'finetuned' either way).
        """
        return 'finetuned' if getattr(self, 'is_finetuned', False) else 'pretrained'

    @staticmethod
    def _disposal_method(row):
        """Join the non-empty Instruction_1..3 columns of a database row.

        Returns:
            A single space-joined instruction string, or a placeholder when
            all three cells are missing/blank.
        """
        disposal_parts = []
        for col in ['Instruction_1', 'Instruction_2', 'Instruction_3']:
            value = row[col]
            # str() guard: numeric CSV cells pass pd.notna but have no .strip()
            if pd.notna(value) and str(value).strip():
                disposal_parts.append(str(value).strip())
        return ' '.join(disposal_parts) if disposal_parts else "No instructions available"

    def classify_image(self, image_path_or_pil, top_k=5):
        """Classify a waste item from an image using fluid OpenCLIP matching.

        Args:
            image_path_or_pil: Path to an image file, or an in-memory PIL image.
            top_k: Number of best-matching database items to return.

        Returns:
            Dict with 'predicted_item', 'predicted_category',
            'best_confidence', 'top_items' and 'model_type'; on failure a
            dict with a single 'error' key (this method never raises).
        """
        try:
            # Handle image input: file path or PIL image
            if isinstance(image_path_or_pil, str):
                if not os.path.exists(image_path_or_pil):
                    return {"error": f"Image file not found: {image_path_or_pil}"}
                image = Image.open(image_path_or_pil).convert('RGB')
            else:
                image = image_path_or_pil.convert('RGB')
            # Get image embedding using the loaded encoder
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_embedding = self.model.encode_image(image_tensor)
                image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            # Cosine similarity against all database items (fluid matching)
            similarities = (image_embedding @ self.item_embeddings.T).cpu().numpy()[0]
            # Top matches, highest similarity first
            top_indices = np.argsort(similarities)[::-1][:top_k]
            results = []
            for idx in top_indices:
                row = self.df.iloc[idx]
                results.append({
                    'item': row['Item'],
                    'category': row['Category'],
                    'disposal_method': self._disposal_method(row),
                    'confidence': float(similarities[idx])
                })
            best_match = results[0] if results else None
            return {
                'predicted_item': best_match['item'] if best_match else "Unknown",
                'predicted_category': best_match['category'] if best_match else "Unknown",
                'best_confidence': best_match['confidence'] if best_match else 0.0,
                'top_items': results,
                'model_type': self._model_type()
            }
        except Exception as e:
            return {"error": f"Classification error: {str(e)}"}

    def get_model_info(self):
        """Get information about the loaded model."""
        return {
            'model_name': self.model_name,
            'pretrained': getattr(self, 'pretrained', 'Unknown'),
            'num_classes': len(self.df) if hasattr(self, 'df') else 0,
            'classes': list(self.df['Item']) if hasattr(self, 'df') else [],
            'model_path': getattr(self, 'model_path', 'Unknown'),
            'device': self.device,
            'model_type': self._model_type(),
            'matching_type': 'fluid_database'
        }