from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import io


class AttributeClassifier:
    """Zero-shot product category and attribute classifier built on CLIP.

    An image is first scored against a fixed list of category labels; the
    winning category then selects a list of attribute labels that the same
    image is scored against.
    """

    def __init__(self):
        """Load the pretrained CLIP model/processor and define the label sets."""
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.categories = [
            "clothing",
            "footwear",
            "furniture",
            "electronics",
            "accessories",
            "general",
        ]
        # Attribute lists per category. "general" deliberately has no entry,
        # so predict() returns an empty attribute dict for it.
        self.attributes = {
            "clothing": [
                "short sleeve",
                "long sleeve",
                "v-neck",
                "round neck",
                "solid",
                "striped",
                "button",
                "zipper",
            ],
            "furniture": ["wood", "metal", "glass", "fabric", "modern", "vintage", "minimalist"],
            "footwear": ["sneaker", "boot", "loafer", "leather", "canvas"],
            "electronics": ["smartphone", "laptop", "tablet", "black", "silver"],
            "accessories": ["bag", "watch", "belt", "leather", "metal"],
        }

    def _label_probabilities(self, image, labels):
        """Return a 1-D tensor of softmax probabilities of *image* over *labels*.

        The probabilities are normalised across the given labels only, so
        they always sum to 1 regardless of how well any label fits.
        """
        inputs = self.processor(text=labels, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # A single image was passed, so row 0 holds its logits for every label.
        return torch.softmax(outputs.logits_per_image[0], dim=-1)

    def predict(self, image_bytes):
        """Classify an encoded image into a category and score its attributes.

        Args:
            image_bytes: Raw bytes of an image file in a format PIL can open.

        Returns:
            A dict with the keys:
              * "category": the highest-probability label from self.categories.
              * "category_confidence": that label's probability.
              * "attributes": attribute label -> probability for the chosen
                category; empty when the category has no attribute list.
              * "all_category_probs": category label -> probability.
        """
        # convert() returns an independent image, so the handle opened on the
        # byte buffer can be released as soon as the conversion is done.
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert('RGB')

        # First predict category
        cat_probs = self._label_probabilities(image, self.categories)
        best_cat_idx = cat_probs.argmax().item()
        best_category = self.categories[best_cat_idx]
        category_confidence = cat_probs[best_cat_idx].item()

        # Predict attributes for that category (if any)
        attr_list = self.attributes.get(best_category, [])
        attr_scores = {}
        if attr_list:
            attr_probs = self._label_probabilities(image, attr_list)
            attr_scores = dict(zip(attr_list, attr_probs.tolist()))

        return {
            "category": best_category,
            "category_confidence": category_confidence,
            "attributes": attr_scores,
            "all_category_probs": dict(zip(self.categories, cat_probs.tolist())),
        }
# Singleton
_classifier = None


def get_classifier():
    """Return the module-wide AttributeClassifier, building it on first use.

    The instance is stored in the module-level ``_classifier`` variable, so
    every later call returns that same object.
    """
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = AttributeClassifier()
    return _classifier