Product-Intelligence / models /classifier.py
Keramo's picture
Create models/classifier.py
2d16330 verified
Raw
History Blame Contribute Delete
2.57 kB
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import io
class AttributeClassifier:
    """Zero-shot product category and attribute classifier built on CLIP.

    An image is first scored against ``self.categories``; the winning
    category then selects a list of attribute prompts from
    ``self.attributes`` which are scored against the same image.
    """

    # Checkpoint used when the caller does not ask for a specific one.
    DEFAULT_MODEL_NAME = "openai/clip-vit-base-patch32"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the CLIP model and its processor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both
                the model and the processor, so the two always match.
                Defaults to ``DEFAULT_MODEL_NAME``.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        self.categories = [
            "clothing",
            "footwear",
            "furniture",
            "electronics",
            "accessories",
            "general",
        ]

        # Attribute prompts per category. "general" deliberately has no
        # entry, so predict() returns an empty attribute dict for it.
        self.attributes = {
            "clothing": [
                "short sleeve",
                "long sleeve",
                "v-neck",
                "round neck",
                "solid",
                "striped",
                "button",
                "zipper",
            ],
            "furniture": [
                "wood",
                "metal",
                "glass",
                "fabric",
                "modern",
                "vintage",
                "minimalist",
            ],
            "footwear": ["sneaker", "boot", "loafer", "leather", "canvas"],
            "electronics": ["smartphone", "laptop", "tablet", "black", "silver"],
            "accessories": ["bag", "watch", "belt", "leather", "metal"],
        }

    def _score_labels(self, image, labels):
        """Score one image against text labels.

        Args:
            image: The decoded image handed to the processor.
            labels: Text prompts to compare the image with.

        Returns:
            A 1-D tensor with one probability per label, in ``labels``
            order. The values are a softmax over ``labels``, so they sum
            to 1 and depend on which labels are scored together.
        """
        inputs = self.processor(
            text=labels, images=image, return_tensors="pt", padding=True
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Only one image is passed in, so row 0 holds all of its logits.
        return torch.softmax(outputs.logits_per_image[0], dim=-1)

    def predict(self, image_bytes: bytes) -> dict:
        """Classify an encoded image into a category and score its attributes.

        Args:
            image_bytes: Raw bytes of an encoded image file.

        Returns:
            A dict with the keys:

            * ``"category"``: the highest-scoring entry of ``self.categories``.
            * ``"category_confidence"``: that entry's probability (float).
            * ``"attributes"``: ``{attribute: probability}`` for the winning
              category's attribute list, or ``{}`` when the category has
              none. These are a softmax across that list, so attributes
              compete with each other rather than being independent scores.
            * ``"all_category_probs"``: ``{category: probability}`` for
              every category.

        No exception is caught here; a failure while decoding the image or
        running the model propagates to the caller.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Stage 1: choose the category.
        category_probs = self._score_labels(image, self.categories)
        best_index = category_probs.argmax().item()
        best_category = self.categories[best_index]

        # Stage 2: score only the attributes that belong to that category.
        attribute_names = self.attributes.get(best_category, [])
        attribute_scores = {}
        if attribute_names:
            attribute_probs = self._score_labels(image, attribute_names)
            attribute_scores = dict(zip(attribute_names, attribute_probs.tolist()))

        return {
            "category": best_category,
            "category_confidence": category_probs[best_index].item(),
            "attributes": attribute_scores,
            "all_category_probs": dict(
                zip(self.categories, category_probs.tolist())
            ),
        }
# Module-level singleton: holds the one shared classifier once it exists.
_classifier = None


def get_classifier():
    """Return the shared ``AttributeClassifier``, creating it on first use.

    The instance is stored in the module-level ``_classifier`` variable, so
    every later call returns the same object instead of building a new one.
    """
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = AttributeClassifier()
    return _classifier