Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Data loader for products from YAML file and OpenSearch indexing.""" | |
| import os | |
| import yaml | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, helpers | |
| import boto3 | |
| from dotenv import load_dotenv | |
# Let python-dotenv populate the process environment before the settings
# below are read.
load_dotenv()

# OpenSearch connection settings, taken from the environment with fallbacks:
# an empty endpoint and the us-east-1 region.
OPENSEARCH_ENDPOINT = os.environ.get("OPENSEARCH_ENDPOINT", "")
AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
def load_products_from_yaml(
    yaml_path: str = "../search_personalization/product_samples/products.yaml",
) -> List[Dict[str, Any]]:
    """
    Load products from YAML file.

    Args:
        yaml_path: Path to the products YAML file. A relative path is resolved
            against the directory that contains this module, not against the
            current working directory.

    Returns:
        The parsed YAML document (expected to be a list of product
        dictionaries), or an empty list when the document is empty/falsy.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
    """
    # Get path relative to this file
    file_path = Path(__file__).parent / yaml_path
    if not file_path.exists():
        raise FileNotFoundError(f"Products YAML file not found at: {file_path}")

    print(f"Loading products from: {file_path}")
    # Decode explicitly as UTF-8: without an encoding argument open() falls
    # back to the platform locale, which corrupts or rejects non-ASCII product
    # text on systems whose default is not UTF-8.
    with open(file_path, 'r', encoding='utf-8') as f:
        products = yaml.safe_load(f)

    return products if products else []
def get_sample_products(count: int = 10) -> List[Dict[str, Any]]:
    """
    Return the leading entries of the product catalogue.

    The catalogue is read with load_products_from_yaml() using its default
    path on every call.

    Args:
        count: End index of the slice taken from the start of the catalogue

    Returns:
        The slice ``catalogue[:count]`` of the loaded products
    """
    catalogue = load_products_from_yaml()
    leading = catalogue[:count]
    return leading
_opensearch_client: Optional[OpenSearch] = None
# Host the cached client was created for (None until this function creates
# one); lets the cache notice a request for a different endpoint.
_opensearch_client_host: Optional[str] = None


def get_opensearch_client(endpoint: str = OPENSEARCH_ENDPOINT) -> OpenSearch:
    """
    Return a cached OpenSearch client, creating it on first use.

    The client is cached at module level and reused by later calls that ask
    for the same host or pass an empty endpoint. Asking for a different host
    builds a new client, which then replaces the cached one.

    Authenticates using AWS SigV4 signing with credentials resolved from
    the configured AWS_PROFILE (or the default credential chain).

    Args:
        endpoint: OpenSearch cluster endpoint, with or without an
            ``http(s)://`` scheme. Anything from the first "/" after the
            host (trailing slash, path) is ignored.

    Returns:
        OpenSearch client instance

    Raises:
        ValueError: If no endpoint is given and no client is cached, or if
            AWS credentials cannot be found.
    """
    global _opensearch_client, _opensearch_client_host

    # Reduce the endpoint URL to a bare host name. Dropping everything after
    # the first "/" keeps a trailing slash or path out of the host entry.
    host = (endpoint or '').strip().replace('https://', '').replace('http://', '')
    host = host.split('/', 1)[0]

    # Reuse the cached client unless the caller explicitly asks for another
    # host. A cached client with no recorded host (e.g. assigned from outside
    # this function) is treated as matching, as before.
    if _opensearch_client is not None and (
        not host or _opensearch_client_host in (None, host)
    ):
        return _opensearch_client

    if not host:
        # Fail here with a clear message instead of building a client for an
        # empty host that only errors on its first request.
        raise ValueError(
            "OpenSearch endpoint is not configured (set OPENSEARCH_ENDPOINT)"
        )

    # Authenticate via AWS SigV4 signing using IAM role
    aws_profile = os.getenv("AWS_PROFILE")
    session = boto3.Session(profile_name=aws_profile) if aws_profile else boto3.Session()
    credentials = session.get_credentials()
    if not credentials:
        raise ValueError("AWS credentials not found")
    http_auth = AWSV4SignerAuth(credentials, AWS_REGION, 'es')

    # Create OpenSearch client
    _opensearch_client = OpenSearch(
        hosts=[{'host': host, 'port': 443}],
        http_auth=http_auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        timeout=30,
        max_retries=3,
        retry_on_timeout=True
    )
    _opensearch_client_host = host
    return _opensearch_client
def create_products_index(client: OpenSearch, index_name: str = 'products') -> None:
    """
    Build the products index from scratch, including its k-NN vector field.

    An existing index called `index_name` is deleted first, so anything
    stored in it is lost.

    Args:
        client: OpenSearch client used for the index operations
        index_name: Name of the index to (re)create
    """
    # Name stored as the index's `default_pipeline` setting; overridable
    # through the OPENSEARCH_INGEST_PIPELINE environment variable.
    pipeline_name = os.getenv('OPENSEARCH_INGEST_PIPELINE', 'product-multimodal-pipeline')

    index_settings = {
        'default_pipeline': pipeline_name,
        'index.knn': True,
        'number_of_shards': 1,
        'number_of_replicas': 1,
    }

    # 1024-dimension vector field using HNSW on the faiss engine with cosine
    # similarity.
    vector_field = {
        'type': 'knn_vector',
        'dimension': 1024,
        'method': {
            'name': 'hnsw',
            'space_type': 'cosinesimil',
            'engine': 'faiss',
            'parameters': {
                'ef_construction': 128,
                'm': 16,
            },
        },
    }

    field_mappings = {
        'id': {'type': 'keyword'},
        'name': {'type': 'text', 'analyzer': 'standard'},
        'description': {'type': 'text', 'analyzer': 'standard'},
        'price': {'type': 'float'},
        'category': {'type': 'keyword'},
        'style': {'type': 'keyword'},
        'current_stock': {'type': 'integer'},
        'image': {'type': 'keyword'},
        'gender_affinity': {'type': 'keyword'},
        'where_visible': {'type': 'keyword'},
        'promoted': {'type': 'boolean'},
        'product_description_vector': vector_field,
    }

    index_body = {
        'settings': index_settings,
        'mappings': {'properties': field_mappings},
    }

    # Drop any previous incarnation so the new mappings always take effect.
    already_there = client.indices.exists(index=index_name)
    if already_there:
        print(f"Deleting existing index: {index_name}")
        client.indices.delete(index=index_name)

    print(f"Creating index: {index_name}")
    client.indices.create(index=index_name, body=index_body)
    print(f"Index '{index_name}' created successfully")
def _get_image_base64(product: Dict[str, Any]) -> Optional[str]:
    """Return the product's image file as a base64 string, or None.

    The file is looked up at ``product_samples/<category>/<image>`` one
    directory above this module's directory. None is returned when either
    field is missing or empty, or when nothing exists at that path.
    """
    import base64

    category = product.get('category', '')
    image_filename = product.get('image', '')

    if category and image_filename:
        samples_root = Path(__file__).parent.parent / 'product_samples'
        candidate = samples_root / category / image_filename
        if candidate.exists():
            raw_bytes = candidate.read_bytes()
            return base64.b64encode(raw_bytes).decode('utf-8')

    return None
def _iter_product_actions(products: List[Dict[str, Any]], index_name: str):
    """Yield one bulk action per product, adding `image_binary` when an image file is found."""
    for product in products:
        # Copy so the caller's product dictionaries are never modified.
        doc = dict(product)
        image_b64 = _get_image_base64(product)
        if image_b64:
            doc['image_binary'] = image_b64
        yield {
            '_index': index_name,
            '_id': product['id'],
            '_source': doc
        }


def index_products_to_opensearch(
    products: List[Dict[str, Any]],
    client: Optional[OpenSearch] = None,
    index_name: str = 'products'
) -> tuple[int, int]:
    """
    Index products into OpenSearch using bulk API.

    Each product is sent as one document whose `_id` is the product's `id`.
    When the product's image file can be read, its base64 content is added
    to the document as `image_binary`.

    Args:
        products: List of product dictionaries; each must contain an `id` key
        client: OpenSearch client (the shared client from
            get_opensearch_client() is used if None)
        index_name: Name of the index

    Returns:
        Tuple of (success_count, failed_count)

    Raises:
        KeyError: If a product has no `id` key.
    """
    if client is None:
        client = get_opensearch_client()

    # Bulk index. Actions are produced lazily so that only the documents of
    # the batch currently being sent (with their base64 image payloads) are
    # held in memory, instead of the whole catalogue at once.
    print(f"Indexing {len(products)} products...")
    success, failed = helpers.bulk(
        client,
        _iter_product_actions(products, index_name),
        raise_on_error=False,
    )
    print(f"Successfully indexed: {success} products")
    if failed:
        print(f"Failed to index: {len(failed)} products")

    # Refresh index to make documents searchable
    client.indices.refresh(index=index_name)
    print(f"Index '{index_name}' refreshed")

    return success, len(failed) if failed else 0
def load_yaml_to_opensearch(
    yaml_path: str = "../search_personalization/product_samples/products.yaml",
    endpoint: str = OPENSEARCH_ENDPOINT,
    index_name: str = 'products',
    recreate_index: bool = True
) -> None:
    """
    Load products from YAML and index them into OpenSearch.

    Progress is reported on stdout. With `recreate_index` set, the target
    index is deleted and created again before any document is written.

    Args:
        yaml_path: Path to the products YAML file
        endpoint: OpenSearch cluster endpoint URL
        index_name: Name of the index
        recreate_index: Whether to recreate the index (deletes existing)
    """
    # Load products from YAML
    print("Loading products from YAML...")
    products = load_products_from_yaml(yaml_path)
    print(f"Loaded {len(products)} products")

    # Get OpenSearch client
    print(f"Connecting to OpenSearch at {endpoint}...")
    client = get_opensearch_client(endpoint)

    # Test connection: info() fails here, before any index is touched, if the
    # cluster cannot be reached.
    info = client.info()
    print(f"Connected to cluster: {info['cluster_name']}")

    # Create index if needed
    if recreate_index:
        create_products_index(client, index_name)

    # Index products
    success, failed = index_products_to_opensearch(products, client, index_name)
    # The leading glyph is a check mark (U+2713); it had been mis-decoded.
    print(f"\n✓ Indexing complete: {success} successful, {failed} failed")
| # --- Persona Index --- | |
# Demo personas indexed into the personas index; `user_id` is used as the
# document id. Emoji values are written as ASCII escapes so the glyphs cannot
# be corrupted again by a lossy re-encoding of this file.
PERSONAS = [
    {
        "user_id": "anonymous",
        "name": "Anonymous User",
        "age": 30,
        "gender": "unspecified",
        "location": "General",
        "occupation": "General User",
        "interests": ["General Shopping"],
        "shopping_behavior": "Standard search without personalization",
        "key": "anonymous",
        "emoji": "\U0001F464"  # bust in silhouette
    },
    {
        "name": "Sarah",
        "user_id": "user1",
        "age": 32,
        "gender": "female",
        "location": "San Francisco, CA",
        "occupation": "Marketing Manager",
        "interests": ["Technology", "Fitness", "Travel"],
        "shopping_behavior": "Focused on quality and trends.",
        "key": "business_professional",
        "emoji": "\U0001F469\u200D\U0001F4BC"  # woman office worker (ZWJ sequence)
    },
    {
        "name": "Alex",
        "user_id": "user2",
        "age": 21,
        "gender": "male",
        "location": "Austin, TX",
        "occupation": "Computer Science Student",
        "interests": ["Gaming", "Programming", "Music"],
        "shopping_behavior": "Budget-conscious, looks for deals and discounts",
        "key": "student",
        "emoji": "\U0001F393"  # graduation cap
    }
]
def create_personas_index(client: OpenSearch, index_name: str = 'personas') -> None:
    """
    Build the personas index from scratch with explicit field mappings.

    An existing index called `index_name` is deleted first, so anything
    stored in it is lost.

    Args:
        client: OpenSearch client used for the index operations
        index_name: Name of the index to (re)create
    """
    index_settings = {
        'number_of_shards': 1,
        'number_of_replicas': 1,
    }

    # `location` and `occupation` are full-text fields that also expose an
    # exact-match `keyword` sub-field.
    field_mappings = {
        'user_id': {'type': 'keyword'},
        'name': {'type': 'text'},
        'age': {'type': 'integer'},
        'gender': {'type': 'keyword'},
        'location': {'type': 'text', 'fields': {'keyword': {'type': 'keyword'}}},
        'occupation': {'type': 'text', 'fields': {'keyword': {'type': 'keyword'}}},
        'interests': {'type': 'keyword'},
        'shopping_behavior': {'type': 'text'},
        'key': {'type': 'keyword'},
        'emoji': {'type': 'keyword'},
    }

    index_body = {
        'settings': index_settings,
        'mappings': {'properties': field_mappings},
    }

    # Drop any previous incarnation so the new mappings always take effect.
    already_there = client.indices.exists(index=index_name)
    if already_there:
        print(f"Deleting existing index: {index_name}")
        client.indices.delete(index=index_name)

    print(f"Creating index: {index_name}")
    client.indices.create(index=index_name, body=index_body)
    print(f"Index '{index_name}' created successfully")
def index_personas(
    client: OpenSearch = None,
    index_name: str = 'personas'
) -> tuple[int, int]:
    """
    Bulk-index every entry of PERSONAS, using `user_id` as the document id.

    Args:
        client: OpenSearch client (the shared client from
            get_opensearch_client() is used if None)
        index_name: Name of the index to write to

    Returns:
        Tuple of (success_count, failed_count)
    """
    if client is None:
        client = get_opensearch_client()

    bulk_actions = [
        {
            '_index': index_name,
            '_id': persona['user_id'],
            '_source': persona,
        }
        for persona in PERSONAS
    ]

    print(f"Indexing {len(PERSONAS)} personas...")
    indexed, errors = helpers.bulk(client, bulk_actions, raise_on_error=False)
    print(f"Successfully indexed: {indexed} personas")
    if errors:
        print(f"Failed to index: {len(errors)} personas")

    # Refresh so the documents just written are visible to searches.
    client.indices.refresh(index=index_name)
    print(f"Index '{index_name}' refreshed")

    failed_count = len(errors) if errors else 0
    return indexed, failed_count
def load_personas_to_opensearch(
    endpoint: str = OPENSEARCH_ENDPOINT,
    index_name: str = 'personas',
    recreate_index: bool = True
) -> None:
    """
    Create the personas index and load all persona documents.

    Progress is reported on stdout. With `recreate_index` set, the target
    index is deleted and created again before any document is written.

    Args:
        endpoint: OpenSearch cluster endpoint URL
        index_name: Name of the index
        recreate_index: Whether to recreate the index (deletes existing)
    """
    print(f"Connecting to OpenSearch at {endpoint}...")
    client = get_opensearch_client(endpoint)

    # Test connection: info() fails here, before any index is touched, if the
    # cluster cannot be reached.
    info = client.info()
    print(f"Connected to cluster: {info['cluster_name']}")

    if recreate_index:
        create_personas_index(client, index_name)

    success, failed = index_personas(client, index_name)
    # The leading glyph is a check mark (U+2713); it had been mis-decoded.
    print(f"\n✓ Persona indexing complete: {success} successful, {failed} failed")
if __name__ == '__main__':
    # Script entry point: index the product catalogue using the default YAML
    # path, endpoint and index name. Personas are not loaded from here.
    load_yaml_to_opensearch()