Spaces:
Sleeping
Sleeping
| from transformers import CLIPProcessor, CLIPModel | |
| import torch | |
| from PIL import Image | |
| import io | |
class AttributeClassifier:
    """Zero-shot product classifier built on CLIP.

    ``predict`` first chooses the best-matching product category for an
    image, then scores that category's attribute labels against the same
    image. Labels are used verbatim as CLIP text prompts.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model/processor and set up the label vocabularies.

        Args:
            model_name: checkpoint id or local path handed to both
                ``CLIPModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``. The default is the
                checkpoint this class previously hard-coded.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.categories = ["clothing", "footwear", "furniture", "electronics", "accessories", "general"]
        # Attribute lists per category. A category with no entry here
        # (currently "general") gets an empty attribute result.
        self.attributes = {
            "clothing": ["short sleeve", "long sleeve", "v-neck", "round neck", "solid", "striped", "button", "zipper"],
            "furniture": ["wood", "metal", "glass", "fabric", "modern", "vintage", "minimalist"],
            "footwear": ["sneaker", "boot", "loafer", "leather", "canvas"],
            "electronics": ["smartphone", "laptop", "tablet", "black", "silver"],
            "accessories": ["bag", "watch", "belt", "leather", "metal"],
        }

    @staticmethod
    def _probs_to_dict(labels, probs):
        """Pair each label with its probability, converted to a Python float."""
        return {label: prob.item() for label, prob in zip(labels, probs)}

    def _label_probs(self, pixel_values, labels):
        """Score ``labels`` against one preprocessed image.

        Args:
            pixel_values: image tensor produced by ``self.processor``. Only
                row 0 of ``logits_per_image`` is read, so a single image is
                expected.
            labels: non-empty list of text prompts.

        Returns:
            1-D tensor of softmax probabilities, one per label, summing to 1.
        """
        text_inputs = self.processor(text=labels, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values, **text_inputs)
        return torch.softmax(outputs.logits_per_image[0], dim=-1)

    def predict(self, image_bytes):
        """Classify an encoded image and score its category's attributes.

        Args:
            image_bytes: raw bytes of an image file that PIL can decode.

        Returns:
            dict with keys:
              - "category": best-scoring entry of ``self.categories``;
              - "category_confidence": its softmax probability;
              - "attributes": label -> probability for that category's
                attribute list, or ``{}`` if the category has none. NOTE:
                these are softmax-normalized across the whole list, so they
                are relative scores summing to 1, not independent
                per-attribute probabilities;
              - "all_category_probs": category -> probability.

        Raises:
            PIL.UnidentifiedImageError: if the bytes cannot be identified as
                an image.
        """
        # Close the decoder once the RGB copy exists; convert() loads the
        # pixel data, so the copy stays valid after the source is closed.
        with Image.open(io.BytesIO(image_bytes)) as raw:
            image = raw.convert('RGB')
        # Preprocess the image once and reuse it for both scoring passes.
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"]

        # First predict category
        cat_probs = self._label_probs(pixel_values, self.categories)
        best_cat_idx = cat_probs.argmax().item()
        best_category = self.categories[best_cat_idx]
        category_confidence = cat_probs[best_cat_idx].item()

        # Predict attributes for that category (if any)
        attr_list = self.attributes.get(best_category, [])
        attr_scores = {}
        if attr_list:
            attr_probs = self._label_probs(pixel_values, attr_list)
            attr_scores = self._probs_to_dict(attr_list, attr_probs)

        return {
            "category": best_category,
            "category_confidence": category_confidence,
            "attributes": attr_scores,
            "all_category_probs": self._probs_to_dict(self.categories, cat_probs),
        }
# Process-wide AttributeClassifier, built lazily on the first request.
_classifier = None


def get_classifier():
    """Return the shared AttributeClassifier, creating it on first call.

    Later calls return the same instance, so the CLIP weights load only once
    per process.
    """
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = AttributeClassifier()
    return _classifier