prasadnu's picture
feat: add Search Personalization demo module
b4d5c9a
Raw
History Blame
11.7 kB
"""
Strands @tool functions for agents to read/write OpenSearch agentic memory.
These tools are shared across the Query Understanding and Ranking agents.
"""
import json
import os
from typing import Optional
from strands import tool
from search_personalization.agentic_memory.config import CONTAINER_ID
from search_personalization.agentic_memory.memory_client import read_memory, write_memory
# Process-local memo of rerank_results() JSON responses. Keys are SHA-256 hex
# digests derived from the rerank request; values are the serialized output.
_rerank_cache: dict[str, str] = dict()
def _get_container_id() -> str:
    """Return the module-level CONTAINER_ID used for every memory operation.

    Raises:
        ValueError: when CONTAINER_ID is empty/None (i.e. setup has not run
            and the environment variable is not set).
    """
    if CONTAINER_ID:
        return CONTAINER_ID
    raise ValueError(
        "MEMORY_CONTAINER_ID not configured. Run setup or set the env var."
    )
@tool
def get_user_profile(persona_id: str) -> str:
    """Retrieve the user's long-term memory profile from OpenSearch agentic memory.
    Returns extracted preferences including color preferences, style, size, price sensitivity,
    and any explicit aversions learned from purchase history, returns, and reviews.
    Args:
        persona_id: The user ID (e.g., user1, user2).
    """
    # Output: JSON with persona_id, the raw long-term memory hits, and a trace
    # entry recording how many USER_PREFERENCE records were read.
    from search_personalization.agentic_memory.memory_client import _client

    container_id = _get_container_id()
    search_path = (
        f"/_plugins/_ml/memory_containers/{container_id}/memories/long-term/_search"
    )
    # Only long-term records for this user that came from the USER_PREFERENCE
    # extraction strategy (at most 20 of them).
    must_clauses = [
        {"term": {"namespace.user_id": persona_id}},
        {"term": {"strategy_type": "USER_PREFERENCE"}},
    ]
    response = _client().transport.perform_request(
        "GET",
        search_path,
        body={"query": {"bool": {"must": must_clauses}}, "size": 20},
    )
    hits = response.get("hits", {}).get("hits", [])

    read_entry = {
        "namespace": {"user_id": persona_id},
        "type": "long-term",
        "strategy_filter": "USER_PREFERENCE",
        "records_retrieved": len(hits),
    }
    payload = {
        "persona_id": persona_id,
        "memories": hits,
        "trace": {"memory_reads": [read_entry]},
    }
    return json.dumps(payload, indent=2, default=str)
@tool
def read_session_memory(persona_id: str, session_id: str) -> str:
    """Retrieve current session context for multi-turn query understanding.
    Returns previous queries and interactions within this shopping session.
    Args:
        persona_id: The user ID (e.g., user1, user2).
        session_id: The current session identifier.
    """
    # Output: JSON with the ids, the raw session-memory hits, and a trace entry.
    container_id = _get_container_id()
    scope = {"user_id": persona_id, "session_id": session_id}
    response = read_memory(
        container_id=container_id,
        namespace=scope,
        memory_type="sessions",
    )
    hits = response.get("hits", {}).get("hits", [])

    read_entry = {
        "namespace": scope,
        "type": "session",
        "records_retrieved": len(hits),
    }
    payload = {
        "persona_id": persona_id,
        "session_id": session_id,
        "memories": hits,
        "trace": {"memory_reads": [read_entry]},
    }
    return json.dumps(payload, indent=2, default=str)
@tool
def write_session_memory(persona_id: str, session_id: str, content: str) -> str:
    """Write a conversation turn to session memory for multi-turn context.
    Args:
        persona_id: The user ID.
        session_id: The current session identifier.
        content: The conversation turn content to store (query + results summary).
    """
    # Returns a JSON {"status": "written", "trace": ...}; the write_memory
    # response itself is not surfaced to the agent.
    container_id = _get_container_id()
    namespace = {"user_id": persona_id, "session_id": session_id}
    write_memory(
        container_id=container_id,
        namespace=namespace,
        content=content,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "session",
                "action": "write_session_turn",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
@tool
def write_working_memory(persona_id: str, session_id: str, enriched_query: str) -> str:
    """Write the enriched query decomposition to working memory for the Ranking Agent to consume.
    This is the inter-agent communication channel.
    Args:
        persona_id: The user ID.
        session_id: The current session identifier.
        enriched_query: JSON string containing the enriched, decomposed query with all inferred attributes.
    """
    # Returns a JSON {"status": "written", "trace": ...}; the write_memory
    # response itself is not surfaced to the agent.
    container_id = _get_container_id()
    namespace = {"user_id": persona_id, "session_id": session_id}
    # infer=False: store the enriched query verbatim. Presumably this skips
    # write_memory's inference/extraction step; confirm against memory_client.
    write_memory(
        container_id=container_id,
        namespace=namespace,
        content=enriched_query,
        infer=False,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "working-memory",
                "action": "write_enriched_query",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
@tool
def write_history(persona_id: str, session_id: str, content: str) -> str:
    """Append a completed interaction to the user's history for long-term pattern learning.
    Args:
        persona_id: The user ID.
        session_id: The session identifier.
        content: Summary of the completed interaction (query, enrichment, results, outcome).
    """
    # Returns a JSON {"status": "written", "trace": ...}; the write_memory
    # response itself is not surfaced to the agent.
    container_id = _get_container_id()
    namespace = {"user_id": persona_id, "session_id": session_id}
    # NOTE(review): this write_memory call has the same arguments shape as
    # write_session_memory's (same session-scoped namespace, no memory_type or
    # infer override); only the trace label below differs. Confirm that
    # write_memory actually files this as "history".
    write_memory(
        container_id=container_id,
        namespace=namespace,
        content=content,
    )
    trace = {
        "memory_writes": [
            {
                "namespace": namespace,
                "type": "history",
                "action": "append_interaction",
            }
        ]
    }
    return json.dumps({"status": "written", "trace": trace}, indent=2, default=str)
def _summarize_product_hit(hit: dict) -> dict:
    """Project one OpenSearch hit onto the compact fields returned to the agent."""
    src = hit.get("_source") or {}
    return {
        "id": src.get("id"),
        "name": src.get("name"),
        "category": src.get("category"),
        "style": src.get("style"),
        "price": src.get("price"),
        # `or ""` also covers an explicit null description, which would
        # otherwise make the slice raise TypeError.
        "description": (src.get("description") or "")[:150],
        "score": hit.get("_score"),
    }


@tool
def search_product_catalog(query: str, category: Optional[str] = None, max_price: Optional[float] = None, gender_affinity: Optional[str] = None, size: int = 10) -> str:
    """Search the product catalog using server-side neural query (OpenSearch does the embedding).
    Style/product-type matching is handled semantically by the vector search — do NOT use term filters
    for fields with semantic meaning.
    Only truly categorical/numeric fields are used as hard filters:
    - category: broad taxonomy (5 values) — safe for exact filtering
    - max_price: numeric range — safe for range filtering
    - gender_affinity: hard filter ("M" or "F") to ensure persona-appropriate results
    Args:
        query: Natural language search query (e.g., "navy leather boots").
        category: Optional category filter (apparel, footwear, accessories, jewelry, electronics).
        max_price: Optional maximum price filter.
        gender_affinity: Optional gender affinity filter ("M" or "F"). Hard-filters to persona-appropriate products.
        size: Max number of results to return.
    """
    # Output: JSON with total_hits, compact per-hit summaries, and the exact
    # OpenSearch request body under "_opensearch_query" for tracing.
    from search_personalization.data_loader import get_opensearch_client

    client = get_opensearch_client()
    model_id = os.getenv('OPENSEARCH_MODEL_ID', 'default_model_id')

    # Hard filters: only for truly categorical/numeric fields.
    filter_clauses = []
    if category:
        filter_clauses.append({"term": {"category": category}})
    # Truthiness test: max_price of None or 0 both mean "no price cap".
    if max_price:
        filter_clauses.append({"range": {"price": {"lte": max_price}}})
    if gender_affinity:
        filter_clauses.append({"term": {"gender_affinity": gender_affinity.upper()}})

    # Server-side neural query: OpenSearch calls the embedding model via its
    # ML connector using model_id.
    neural_params: dict = {
        "query_text": query,
        "model_id": model_id,
        "k": size,
    }
    if filter_clauses:
        # Neural query's built-in filter (pre-filter before kNN neighbor selection).
        neural_params["filter"] = {"bool": {"filter": filter_clauses}}
    query_body = {
        "size": size,
        "query": {"neural": {"product_description_vector": neural_params}},
    }

    resp = client.search(index="products", body=query_body)
    hits_section = resp["hits"]
    results = [_summarize_product_hit(h) for h in hits_section["hits"]]
    # hits.total is normally {"value": N, ...} but is a bare int when the
    # cluster reports total hits as an integer; accept both shapes.
    total = hits_section.get("total")
    total_hits = total.get("value") if isinstance(total, dict) else total
    return json.dumps(
        {"total_hits": total_hits, "results": results, "_opensearch_query": query_body},
        indent=2,
        default=str,
    )
# Upper bound on _rerank_cache entries; the oldest entries are evicted first.
_RERANK_CACHE_MAX_ENTRIES = 256


def _rerank_cache_key(query, documents, top_n) -> str:
    """Return an unambiguous SHA-256 key for one rerank request."""
    import hashlib

    if isinstance(documents, str):
        docs_repr = documents
    else:
        docs_repr = json.dumps(documents, sort_keys=True, default=str)
    # Serializing a list delimits each field, so ("ab", "c") and ("a", "bc")
    # no longer hash identically. top_n is part of the key because it changes
    # the result.
    payload = json.dumps([query, docs_repr, top_n], default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def _cache_rerank_result(key: str, value: str) -> None:
    """Store value in _rerank_cache, evicting the oldest entries beyond the cap."""
    while len(_rerank_cache) >= _RERANK_CACHE_MAX_ENTRIES:
        # dicts preserve insertion order, so the first key is the oldest entry.
        del _rerank_cache[next(iter(_rerank_cache))]
    _rerank_cache[key] = value


def _extract_rerank_docs(parsed) -> list:
    """Pull the list of product dicts out of any accepted `documents` shape."""
    if isinstance(parsed, dict):
        # Full search_product_catalog output, or a prior rerank_results output.
        if "results" in parsed:
            docs = parsed["results"]
        else:
            docs = parsed.get("reranked_results", [])
    elif isinstance(parsed, list):
        docs = parsed
    else:
        docs = []
    if not isinstance(docs, list):
        return []
    # Non-dict entries cannot be formatted via .get() nor annotated with a score.
    return [d for d in docs if isinstance(d, dict)]


def _rerank_doc_text(doc: dict) -> str:
    """Flatten a product dict into the plain string Cohere Rerank expects."""
    return (
        f"{doc.get('name', '')}. {doc.get('description', '')}. "
        f"Style: {doc.get('style', '')}. Category: {doc.get('category', '')}. "
        f"Price: ${doc.get('price', '')}"
    )


@tool
def rerank_results(query: str, documents: str, top_n: int = 10) -> str:
    """Rerank search results using Cohere Rerank 3.5 via Bedrock.
    Call this AFTER search_product_catalog to reorder results by relevance
    to the enriched query. Pass the enriched query (with user preferences baked in)
    as the query parameter for personalized reranking.
    Args:
        query: The enriched query string to rerank against (e.g., "spacious tan leather backpack, neutral colors, business-casual, $75-$175").
        documents: JSON string — either the full search_product_catalog output or just the results array.
        top_n: Number of top results to return after reranking.
    """
    # Output: JSON {"reranked_results": [...]} (each entry is the original
    # product dict plus "rerank_score"), or {"error": ...} for unusable input.
    # Successful results are memoized in _rerank_cache.
    import boto3

    cache_key = _rerank_cache_key(query, documents, top_n)
    cached = _rerank_cache.get(cache_key)
    if cached is not None:
        return cached

    if isinstance(documents, str):
        try:
            parsed = json.loads(documents)
        except json.JSONDecodeError as exc:
            # Same error-JSON shape as the empty-documents case below.
            return json.dumps({"error": f"Invalid documents JSON: {exc}"})
    else:
        parsed = documents

    docs = _extract_rerank_docs(parsed)
    if not docs:
        return json.dumps({"error": "No documents to rerank"})

    body = json.dumps({
        "query": query,
        # Cohere Rerank API requires plain strings.
        "documents": [_rerank_doc_text(d) for d in docs],
        # Clamp to [1, len(docs)] so the request is always valid.
        "top_n": max(1, min(top_n, len(docs))),
        "api_version": 2,
    })
    client = boto3.client("bedrock-runtime", region_name=os.getenv("AWS_REGION", "us-east-1"))
    response = client.invoke_model(
        modelId="cohere.rerank-v3-5:0",
        body=body,
        contentType="application/json",
        accept="application/json",
    )
    result = json.loads(response["body"].read())

    reranked = []
    for item in result.get("results", []):
        idx = item["index"]
        # Reject out-of-range indices, including negatives that would
        # otherwise alias from the end of the list.
        if 0 <= idx < len(docs):
            entry = dict(docs[idx])
            entry["rerank_score"] = item["relevance_score"]
            reranked.append(entry)

    result_json = json.dumps({"reranked_results": reranked}, indent=2, default=str)
    _cache_rerank_result(cache_key, result_json)
    return result_json