Spaces:
Sleeping
Sleeping
| # # src/classifier.py | |
| # from sentence_transformers import SentenceTransformer | |
| # import numpy as np | |
| # import pickle | |
| # class ProductClassifier: | |
| # def __init__(self, model_path="./models"): | |
| # self.model = SentenceTransformer("all-mpnet-base-v2") | |
| # self.embeddings = np.load(f"{model_path}/category_embeddings_mpnet.npy") | |
| # with open(f"{model_path}/category_metadata.pkl", "rb") as f: | |
| # self.metadata = pickle.load(f) | |
| # def classify(self, product_data, top_k=5): | |
| # # Implementation here | |
| # pass | |
| # """ | |
| # Product Classification Engine | |
| # Loads pre-trained embeddings and performs similarity-based classification | |
| # """ | |
| import numpy as np | |
| import pickle | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from typing import Dict, List, Optional | |
| import re | |
| import logging | |
| from .config import ( | |
| MODEL_NAME, | |
| EMBEDDINGS_FILE, | |
| METADATA_FILE, | |
| AUTO_APPROVE_THRESHOLD, | |
| QUICK_REVIEW_THRESHOLD, | |
| BOOST_FACTOR, | |
| MAX_BOOST, | |
| DEFAULT_TOP_K, | |
| PRODUCT_KEYWORDS, | |
| ) | |
# Set up module-level logging.
# NOTE(review): logging.basicConfig at import time configures the *root*
# logger for the whole process; confirm this is intended, since importing
# this module from an application will override its logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ProductClassifier:
    """
    ML-powered product classifier for insurance categorization.

    Embeds incoming product data with a SentenceTransformer model and ranks
    pre-computed category embeddings by cosine similarity, optionally boosting
    categories whose text contains keywords detected in the product.
    """

    def __init__(self):
        """Initialize the classifier by loading the model, embeddings and metadata."""
        logger.info("Initializing Product Classifier...")

        # Load the sentence-embedding model.
        logger.info("Loading model: %s", MODEL_NAME)
        self.model = SentenceTransformer(MODEL_NAME)
        logger.info(
            "Model loaded (dimension: %s)",
            self.model.get_sentence_embedding_dimension(),
        )

        # Load pre-computed category embeddings (one row per category).
        logger.info("Loading category embeddings from %s", EMBEDDINGS_FILE)
        self.embeddings = np.load(EMBEDDINGS_FILE)
        logger.info("Loaded %d category embeddings", self.embeddings.shape[0])

        # Load category metadata.
        # NOTE(review): pickle.load is only safe because METADATA_FILE is a
        # trusted artifact produced by our own pipeline -- never point this
        # at untrusted data.
        logger.info("Loading metadata from %s", METADATA_FILE)
        with open(METADATA_FILE, "rb") as f:
            self.metadata = pickle.load(f)
        logger.info("Metadata loaded")

        # Category texts used by the keyword-boosting step.
        self.embedding_texts = self.metadata.get("embedding_texts", [])
        logger.info("Classifier ready!")

    def preprocess_product(self, product_data: Dict) -> str:
        """
        Combine product fields into a single searchable text string.

        Args:
            product_data: Dictionary with product fields:
                - title (str): Product title
                - product_type (str, optional): Product type/category
                - vendor (str, optional): Brand/vendor name
                - tags (list/str, optional): Product tags
                - description (str, optional): Product description

        Returns:
            Processed text string suitable for embedding.
        """
        parts: List[str] = []

        title = product_data.get("title", "")
        product_type = product_data.get("product_type", "")
        vendor = product_data.get("vendor", "")
        description = product_data.get("description", "")
        tags = product_data.get("tags", [])

        # 1. Title carries the strongest signal.
        if title:
            parts.append(title)
        # 2. Product type is a direct category hint.
        if product_type:
            parts.append(f"Product type: {product_type}")
        # 3. Brand/vendor.
        if vendor:
            parts.append(f"Brand: {vendor}")
        # 4. Tags: accept either a list (items coerced to str) or a plain string.
        if tags:
            if isinstance(tags, list):
                tag_text = " ".join(str(t) for t in tags)
            else:
                tag_text = tags
            parts.append(f"Keywords: {tag_text}")
        # 5. Description, truncated to 100 chars; skip if blank after trimming
        #    so we never emit an empty trailing part.
        if description:
            desc_short = description[:100].strip()
            if desc_short:
                parts.append(desc_short)

        return ". ".join(parts)

    def extract_keywords(self, text: str) -> List[str]:
        """
        Return the configured product keywords that occur in *text*.

        Matching is case-insensitive substring containment against
        PRODUCT_KEYWORDS, preserving the configured keyword order.
        """
        text_lower = text.lower()
        return [kw for kw in PRODUCT_KEYWORDS if kw in text_lower]

    def classify(
        self, product_data: Dict, top_k: int = DEFAULT_TOP_K, use_boost: bool = True
    ) -> Dict:
        """
        Classify a product into insurance categories.

        Args:
            product_data: Product information dictionary (see preprocess_product).
            top_k: Number of top matches to return (clamped to [1, n_categories]).
            use_boost: Whether to apply keyword-based score boosting.

        Returns:
            Classification results with confidence scores and a recommended
            action (AUTO_APPROVE / QUICK_REVIEW / MANUAL_CATEGORIZATION).
        """
        product_text = self.preprocess_product(product_data)

        # Embed the product; normalized embeddings make the dot product a
        # cosine similarity in [-1, 1].
        product_embedding = self.model.encode(
            [product_text], normalize_embeddings=True
        )
        semantic_scores = cosine_similarity(product_embedding, self.embeddings)[0]

        if use_boost:
            boosted_scores = self._apply_keyword_boost(
                semantic_scores, self.extract_keywords(product_text)
            )
        else:
            boosted_scores = semantic_scores

        # Clamp top_k so results[0] below is always valid even for top_k <= 0.
        top_k = max(1, min(top_k, len(boosted_scores)))
        top_indices = np.argsort(boosted_scores)[::-1][:top_k]

        results = []
        for rank, idx in enumerate(top_indices, 1):
            category_data = {
                "rank": rank,
                "category_id": self.metadata["category_ids"][idx],
                "category_path": self.metadata["category_paths"][idx],
                "semantic_score": float(semantic_scores[idx]),
                "final_score": float(boosted_scores[idx]),
                "confidence_percentage": round(float(boosted_scores[idx]) * 100, 2),
            }
            if use_boost:
                # Percentage points contributed by the keyword boost; cast to
                # float so the value is JSON-serializable without conversion.
                category_data["boost_applied"] = round(
                    float(boosted_scores[idx] - semantic_scores[idx]) * 100, 2
                )
            results.append(category_data)

        action, reason = self._decide_action(
            results[0]["final_score"], results[0]["confidence_percentage"]
        )

        return {
            "product_id": product_data.get("id", "unknown"),
            "product_text": product_text,
            "action": action,
            "reason": reason,
            "top_category": results[0]["category_path"],
            "top_confidence": results[0]["confidence_percentage"],
            # results[1:3] is already [] when there are no alternatives.
            "alternatives": results[1:3],
            "all_results": results,
        }

    @staticmethod
    def _decide_action(top_score: float, confidence_pct: float):
        """Map the top final score to a (action, human-readable reason) pair."""
        if top_score >= AUTO_APPROVE_THRESHOLD:
            return "AUTO_APPROVE", f"High confidence ({confidence_pct}%)"
        if top_score >= QUICK_REVIEW_THRESHOLD:
            return (
                "QUICK_REVIEW",
                f"Medium confidence ({confidence_pct}%) - verify category",
            )
        return (
            "MANUAL_CATEGORIZATION",
            f"Low confidence ({confidence_pct}%) - needs expert review",
        )

    def _apply_keyword_boost(
        self, scores: np.ndarray, product_keywords: List[str]
    ) -> np.ndarray:
        """
        Boost scores of categories whose text mentions product keywords.

        Args:
            scores: Original semantic similarity scores.
            product_keywords: Keywords detected in the product text.

        Returns:
            A new array of boosted scores; each score is capped at 1.0.
        """
        boosted_scores = scores.copy()
        if not product_keywords:
            return boosted_scores

        for idx, cat_text in enumerate(self.embedding_texts):
            cat_text_lower = cat_text.lower()
            matches = sum(1 for kw in product_keywords if kw in cat_text_lower)
            if matches:
                # Boost grows with the match count but is capped at MAX_BOOST.
                boost = min(matches * BOOST_FACTOR, MAX_BOOST)
                boosted_scores[idx] = min(boosted_scores[idx] + boost, 1.0)
        return boosted_scores

    def classify_batch(
        self, products: List[Dict], top_k: int = DEFAULT_TOP_K
    ) -> List[Dict]:
        """
        Classify multiple products sequentially.

        A failure on one product is reported as an "ERROR" entry instead of
        aborting the whole batch.

        Args:
            products: List of product data dictionaries.
            top_k: Number of top matches per product.

        Returns:
            List of JSON-serializable classification results.
        """
        logger.info("Classifying batch of %d products...", len(products))

        results = []
        for i, product in enumerate(products, 1):
            try:
                result = self.classify(product, top_k=top_k)
                # Strip any remaining numpy scalar/array types so the result
                # can be JSON-serialized downstream.
                results.append(self._convert_to_json_serializable(result))
                if i % 100 == 0:
                    logger.info("  Processed %d/%d products", i, len(products))
            except Exception as e:
                logger.error("  Error classifying product %d: %s", i, e)
                results.append(
                    {
                        "product_id": product.get("id", f"product_{i}"),
                        "action": "ERROR",
                        "reason": str(e),
                        "top_category": None,
                        "top_confidence": 0.0,
                    }
                )
        logger.info("Batch classification complete!")
        return results

    def _convert_to_json_serializable(self, obj):
        """
        Recursively convert numpy scalar/array types to native Python types.

        np.integer / np.floating cover all concrete sized variants
        (int32, int64, float32, float64, ...), so no per-dtype checks
        are needed; numpy is already imported at module level.
        """
        if isinstance(obj, dict):
            return {
                key: self._convert_to_json_serializable(value)
                for key, value in obj.items()
            }
        if isinstance(obj, list):
            return [self._convert_to_json_serializable(item) for item in obj]
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj
# Smoke test: run a single classification when this module is executed directly.
if __name__ == "__main__":
    print("Testing Product Classifier...")
    print("=" * 80)

    # Initialize classifier (loads model, embeddings, metadata).
    classifier = ProductClassifier()

    # Representative test product covering every preprocessed field.
    test_product = {
        "id": "test_001",
        "title": "Apple iPhone 15 Pro Max",
        "product_type": "Smartphone",
        "vendor": "Apple Inc",
        "tags": ["electronics", "mobile", "phone", "smartphone"],
        "description": "Latest flagship smartphone with titanium design",
    }
    print("\nTest Product:")
    print(f"  {test_product['title']}")

    # Classify and report the recommended action.
    result = classifier.classify(test_product)

    print("\nClassification Result:")
    print(f"  Action: {result['action']}")
    print(f"  Top Category: {result['top_category']}")
    print(f"  Confidence: {result['top_confidence']}%")
    print(f"  Reason: {result['reason']}")

    # 'alternatives' holds at most two runner-up categories.
    print("\nTop Alternatives:")
    for alt in result["alternatives"]:
        print(
            f"  {alt['rank']}. {alt['category_path']} ({alt['confidence_percentage']}%)"
        )

    print("\n" + "=" * 80)
    print("Classifier test complete!")