Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """ | |
| Strands @tool functions for agents to read/write OpenSearch agentic memory. | |
| These tools are shared across the Query Understanding and Ranking agents. | |
| """ | |
| import json | |
| import os | |
| from typing import Optional | |
| from strands import tool | |
| from search_personalization.agentic_memory.config import CONTAINER_ID | |
| from search_personalization.agentic_memory.memory_client import read_memory, write_memory | |
# In-memory cache for rerank results.
# Keys are SHA-256 hex digests derived from the rerank inputs; values are the
# JSON strings returned by rerank_results. The dict is module-level state, so
# entries persist for the lifetime of the process.
_rerank_cache: dict[str, str] = {}
def _get_container_id() -> str:
    """Return the single configured memory container ID.

    Raises:
        ValueError: If ``CONTAINER_ID`` is empty or unset.
    """
    if CONTAINER_ID:
        return CONTAINER_ID
    raise ValueError(
        "MEMORY_CONTAINER_ID not configured. Run setup or set the env var."
    )
def get_user_profile(persona_id: str) -> str:
    """Fetch the user's long-term memory profile from OpenSearch agentic memory.

    The profile holds extracted preferences such as color preferences, style,
    size, price sensitivity, and explicit aversions learned from purchase
    history, returns, and reviews.

    Args:
        persona_id: The user ID (e.g., user1, user2).

    Returns:
        A JSON string with ``persona_id``, the raw ``memories`` hits, and a
        ``trace`` entry describing the read.
    """
    from search_personalization.agentic_memory.memory_client import _client

    container_id = _get_container_id()
    search_path = f"/_plugins/_ml/memory_containers/{container_id}/memories/long-term/_search"

    # Only long-term records for this user produced by the USER_PREFERENCE
    # strategy are requested; at most 20 hits come back.
    must_clauses = [
        {"term": {"namespace.user_id": persona_id}},
        {"term": {"strategy_type": "USER_PREFERENCE"}},
    ]
    response = _client().transport.perform_request(
        "GET",
        search_path,
        body={"query": {"bool": {"must": must_clauses}}, "size": 20},
    )
    hits = response.get("hits", {}).get("hits", [])

    read_record = {
        "namespace": {"user_id": persona_id},
        "type": "long-term",
        "strategy_filter": "USER_PREFERENCE",
        "records_retrieved": len(hits),
    }
    payload = {
        "persona_id": persona_id,
        "memories": hits,
        "trace": {"memory_reads": [read_record]},
    }
    return json.dumps(payload, indent=2, default=str)
def read_session_memory(persona_id: str, session_id: str) -> str:
    """Fetch the current session context for multi-turn query understanding.

    The result contains the previous queries and interactions recorded within
    this shopping session.

    Args:
        persona_id: The user ID (e.g., user1, user2).
        session_id: The current session identifier.

    Returns:
        A JSON string with ``persona_id``, ``session_id``, the raw ``memories``
        hits, and a ``trace`` entry describing the read.
    """
    scope = {"user_id": persona_id, "session_id": session_id}
    response = read_memory(
        container_id=_get_container_id(),
        namespace=scope,
        memory_type="sessions",
    )
    hits = response.get("hits", {}).get("hits", [])

    read_record = {
        "namespace": scope,
        "type": "session",
        "records_retrieved": len(hits),
    }
    payload = {
        "persona_id": persona_id,
        "session_id": session_id,
        "memories": hits,
        "trace": {"memory_reads": [read_record]},
    }
    return json.dumps(payload, indent=2, default=str)
def write_session_memory(persona_id: str, session_id: str, content: str) -> str:
    """Write a conversation turn to session memory for multi-turn context.

    Args:
        persona_id: The user ID.
        session_id: The current session identifier.
        content: The conversation turn content to store (query + results summary).

    Returns:
        A JSON string ``{"status": "written", "trace": {...}}``. The status is
        reported whenever ``write_memory`` returns without raising; its
        response is not inspected.

    Raises:
        ValueError: If the memory container ID is not configured.
    """
    namespace = {"user_id": persona_id, "session_id": session_id}
    # The write is performed for its side effect only, so the response is not
    # bound to a name.
    write_memory(
        container_id=_get_container_id(),
        namespace=namespace,
        content=content,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "session",
                "action": "write_session_turn",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
def write_working_memory(persona_id: str, session_id: str, enriched_query: str) -> str:
    """Write the enriched query decomposition to working memory for the Ranking Agent.

    This is the inter-agent communication channel.

    Args:
        persona_id: The user ID.
        session_id: The current session identifier.
        enriched_query: JSON string containing the enriched, decomposed query
            with all inferred attributes.

    Returns:
        A JSON string ``{"status": "written", "trace": {...}}``. The status is
        reported whenever ``write_memory`` returns without raising; its
        response is not inspected.

    Raises:
        ValueError: If the memory container ID is not configured.
    """
    namespace = {"user_id": persona_id, "session_id": session_id}
    # infer=False is forwarded to write_memory; presumably it stores the
    # payload as-is without running memory inference - confirm against
    # memory_client. The response is not needed, so it is not bound.
    write_memory(
        container_id=_get_container_id(),
        namespace=namespace,
        content=enriched_query,
        infer=False,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "working-memory",
                "action": "write_enriched_query",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
def write_history(persona_id: str, session_id: str, content: str) -> str:
    """Append a completed interaction to the user's history for long-term pattern learning.

    Args:
        persona_id: The user ID.
        session_id: The session identifier.
        content: Summary of the completed interaction (query, enrichment,
            results, outcome).

    Returns:
        A JSON string ``{"status": "written", "trace": {...}}``. The status is
        reported whenever ``write_memory`` returns without raising; its
        response is not inspected.

    Raises:
        ValueError: If the memory container ID is not configured.
    """
    namespace = {"user_id": persona_id, "session_id": session_id}
    # NOTE(review): this call passes the same namespace and arguments as
    # write_session_memory, so only the trace label marks it as "history".
    # Confirm whether a distinct namespace or memory type was intended.
    write_memory(
        container_id=_get_container_id(),
        namespace=namespace,
        content=content,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "history",
                "action": "append_interaction",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
def search_product_catalog(
    query: str,
    category: Optional[str] = None,
    max_price: Optional[float] = None,
    gender_affinity: Optional[str] = None,
    size: int = 10,
) -> str:
    """Search the product catalog using a server-side neural query (OpenSearch does the embedding).

    Style/product-type matching is handled semantically by the vector search -
    do NOT use term filters for fields with semantic meaning.

    Only truly categorical/numeric fields are used as hard filters:
    - category: broad taxonomy (5 values) - safe for exact filtering
    - max_price: numeric range - safe for range filtering
    - gender_affinity: hard filter ("M" or "F") to ensure persona-appropriate results

    Args:
        query: Natural language search query (e.g., "navy leather boots").
        category: Optional category filter (apparel, footwear, accessories, jewelry, electronics).
        max_price: Optional maximum price filter. ``None`` means no price
            filter; any number, including 0, is applied as an upper bound.
        gender_affinity: Optional gender affinity filter ("M" or "F"). Hard-filters to persona-appropriate products.
        size: Max number of results to return.

    Returns:
        A JSON string with ``total_hits``, a ``results`` list of product
        summaries (description truncated to 150 characters), and the
        ``_opensearch_query`` body that was sent.
    """
    from search_personalization.data_loader import get_opensearch_client

    client = get_opensearch_client()
    model_id = os.getenv('OPENSEARCH_MODEL_ID', 'default_model_id')

    # Hard filters: only for truly categorical/numeric fields.
    filter_clauses = []
    if category:
        filter_clauses.append({"term": {"category": category}})
    # Compare against None explicitly: a truthiness test would silently drop
    # the filter for max_price == 0.
    if max_price is not None:
        filter_clauses.append({"range": {"price": {"lte": max_price}}})
    if gender_affinity:
        filter_clauses.append({"term": {"gender_affinity": gender_affinity.upper()}})

    # Server-side neural query: OpenSearch calls the embedding model itself.
    vector_clause: dict = {
        "query_text": query,
        "model_id": model_id,
        "k": size,
    }
    if filter_clauses:
        # The neural query's built-in filter pre-filters before kNN neighbor
        # selection.
        vector_clause["filter"] = {"bool": {"filter": filter_clauses}}

    query_body = {
        "size": size,
        "query": {"neural": {"product_description_vector": vector_clause}},
    }
    resp = client.search(index="products", body=query_body)

    results = []
    for hit in resp["hits"]["hits"]:
        src = hit["_source"]
        results.append({
            "id": src.get("id"),
            "name": src.get("name"),
            "category": src.get("category"),
            "style": src.get("style"),
            "price": src.get("price"),
            # `or ""` also covers documents whose description is stored as
            # null, which would otherwise fail on slicing.
            "description": (src.get("description") or "")[:150],
            "score": hit.get("_score"),
        })

    return json.dumps(
        {
            "total_hits": resp["hits"]["total"]["value"],
            "results": results,
            "_opensearch_query": query_body,
        },
        indent=2,
        default=str,
    )
# Upper bound on cached rerank responses; the oldest entry is evicted first.
_RERANK_CACHE_MAX_ENTRIES = 256


def _extract_rerank_docs(documents) -> list:
    """Normalize the tool input into the list of documents to rerank."""
    if isinstance(documents, str):
        if not documents.strip():
            # An empty payload means "nothing to rerank", not malformed JSON.
            return []
        parsed = json.loads(documents)
    else:
        parsed = documents

    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        if "results" in parsed:
            return parsed["results"] or []
        # Also accept a previous rerank_results output.
        return parsed.get("reranked_results", []) or []
    return []


def rerank_results(query: str, documents: str, top_n: int = 10) -> str:
    """Rerank search results using Cohere Rerank 3.5 via Bedrock.

    Call this AFTER search_product_catalog to reorder results by relevance
    to the enriched query. Pass the enriched query (with user preferences baked in)
    as the query parameter for personalized reranking.

    Args:
        query: The enriched query string to rerank against (e.g., "spacious tan leather backpack, neutral colors, business-casual, $75-$175").
        documents: JSON string - either the full search_product_catalog output or just the results array.
        top_n: Number of top results to return after reranking. Clamped to
            the range 1..len(documents).

    Returns:
        A JSON string ``{"reranked_results": [...]}`` where each entry is the
        original document plus a ``rerank_score``, or
        ``{"error": "No documents to rerank"}`` when the input has no documents.

    Raises:
        json.JSONDecodeError: If ``documents`` is a non-empty string that is
            not valid JSON.
    """
    import hashlib

    # Validate the input first so an empty request never imports boto3 or
    # creates a Bedrock client.
    docs = _extract_rerank_docs(documents)
    if not docs:
        return json.dumps({"error": "No documents to rerank"})

    limit = max(1, min(top_n, len(docs)))

    # The key covers query, documents AND the effective top_n. Serializing
    # them as one JSON array keeps the fields unambiguous, so "ab"+"c" cannot
    # collide with "a"+"bc", and a different top_n never returns a stale
    # cached result.
    key_material = json.dumps([query, docs, limit], sort_keys=True, default=str)
    cache_key = hashlib.sha256(key_material.encode()).hexdigest()
    cached = _rerank_cache.get(cache_key)
    if cached is not None:
        return cached

    import boto3

    client = boto3.client("bedrock-runtime", region_name=os.getenv("AWS_REGION", "us-east-1"))

    # Build document strings for the Cohere Rerank API (must be plain strings).
    doc_texts = [
        f"{d.get('name', '')}. {d.get('description', '')}. Style: {d.get('style', '')}. Category: {d.get('category', '')}. Price: ${d.get('price', '')}"
        for d in docs
    ]
    body = json.dumps({
        "query": query,
        "documents": doc_texts,
        "top_n": limit,
        "api_version": 2,
    })
    response = client.invoke_model(
        modelId="cohere.rerank-v3-5:0",
        body=body,
        contentType="application/json",
        accept="application/json",
    )
    result = json.loads(response["body"].read())

    reranked = []
    for item in result.get("results", []):
        idx = item["index"]
        # Ignore indices outside the submitted list; a negative index would
        # otherwise silently wrap around.
        if 0 <= idx < len(docs):
            entry = dict(docs[idx])
            entry["rerank_score"] = item["relevance_score"]
            reranked.append(entry)

    result_json = json.dumps({"reranked_results": reranked}, indent=2, default=str)

    # Keep the cache bounded: dicts preserve insertion order, so the first
    # key is the oldest entry.
    if len(_rerank_cache) >= _RERANK_CACHE_MAX_ENTRIES:
        _rerank_cache.pop(next(iter(_rerank_cache)))
    _rerank_cache[cache_key] = result_json
    return result_json