# synth-net / src/services/search_service.py
# Synced from GitHub by github-actions CI (commit 6ca4b94)
import httpx
import logging
import polars as pl
import numpy as np
from fastapi.applications import FastAPI
from typing import Literal, Optional
from sentence_transformers import util
from sklearn.metrics.pairwise import cosine_similarity
logger = logging.getLogger(__name__)
from src.config import config
from src.utils.logging import context_logger
async def encode(texts: list[str], mode: Literal["item", "scale"] = "item"):
    """Request embeddings for ``texts`` from the local encoding service.

    Args:
        texts: Strings to embed.
        mode: Encoding mode forwarded to the service ("item" or "scale").

    Returns:
        numpy array built from the 'embeddings' field of the JSON response.

    Raises:
        httpx.HTTPStatusError: If the service returns an error status code.
    """
    payload = {"texts": texts, "mode": mode}
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8001/encode",
            json=payload,
            timeout=30.0
        )
        response.raise_for_status()
        embeddings = np.array(response.json()['embeddings'])
    return embeddings
def align_embeddings(item_embeddings, keying):
    """Align reverse-keyed item embeddings by reflecting them across the
    midplane between the positive- and negative-keyed centroids.

    Args:
        item_embeddings: 2-D array of shape (n_items, dim), one embedding
            per item.
        keying: Sequence of "positive"/"negative" labels, one per item.

    Returns:
        Dict with keys 'item_centroid_positive', 'item_centroid_negative',
        'item_embeddings_aligned' and 'item_centroid_aligned'. Every value
        is np.nan when alignment is not possible: a keying group is empty,
        the polarity axis is degenerate, or no item empirically leans
        negative.
    """
    # Sentinel returned whenever alignment cannot be performed; callers
    # detect it via np.isnan on 'item_centroid_aligned'.
    failure = {
        'item_centroid_positive': np.nan,
        'item_centroid_negative': np.nan,
        'item_embeddings_aligned': np.nan,
        'item_centroid_aligned': np.nan,
    }

    keying = np.asarray(keying)
    positive_items = item_embeddings[keying == "positive"]
    negative_items = item_embeddings[keying == "negative"]
    # Both keying groups must be non-empty to define a polarity axis.
    if positive_items.size == 0 or negative_items.size == 0:
        return failure

    item_centroid_positive = positive_items.mean(axis=0)
    item_centroid_negative = negative_items.mean(axis=0)

    # Cosine similarity of every item to the positive centroid, computed
    # directly in numpy instead of round-tripping through torch via
    # sentence_transformers.util.cos_sim; the epsilon mirrors that
    # function's norm clamp so zero vectors do not divide by zero.
    eps = 1e-12
    item_norms = np.maximum(np.linalg.norm(item_embeddings, axis=1), eps)
    centroid_norm = max(np.linalg.norm(item_centroid_positive), eps)
    cosine_similarities = (item_embeddings @ item_centroid_positive) / (item_norms * centroid_norm)
    # Items that empirically lean away from the positive pole.
    synthetic_is_negative = cosine_similarities < 0

    polarity_axis = item_centroid_positive - item_centroid_negative
    axis_magnitude = np.linalg.norm(polarity_axis)
    # Bail out when the axis is degenerate (coinciding centroids, NaNs) or
    # when no item actually needs flipping. np.any is used instead of the
    # builtin any(), which fails on 0-d arrays and is slower on ndarrays.
    if not np.isfinite(axis_magnitude) or axis_magnitude <= 0 or not np.any(synthetic_is_negative):
        return failure

    polarity_unit_vector = polarity_axis / axis_magnitude
    # The reflection plane sits halfway between the two centroids,
    # perpendicular to the polarity axis.
    reflection_plane_center = (item_centroid_positive + item_centroid_negative) / 2
    signed_distances_to_plane = np.dot(
        item_embeddings - reflection_plane_center,
        polarity_unit_vector
    )

    # Only reflect items that are BOTH keyed negative AND empirically lean
    # negative; everything else keeps a zero reflection distance.
    items_to_align = (keying == "negative") & synthetic_is_negative
    reflection_distances = np.where(items_to_align, signed_distances_to_plane, 0)
    # Householder-style reflection of the selected items across the plane.
    item_embeddings_aligned = item_embeddings - 2 * np.outer(
        reflection_distances,
        polarity_unit_vector
    )
    item_centroid_aligned = item_embeddings_aligned.mean(axis=0)

    return {
        'item_centroid_positive': item_centroid_positive,
        'item_centroid_negative': item_centroid_negative,
        'item_embeddings_aligned': item_embeddings_aligned,
        'item_centroid_aligned': item_centroid_aligned,
    }
async def semantic_item_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Embed the query items, align reverse-keyed ones, and score every
    stored item centroid against the resulting query centroid.

    Args:
        queries: Dicts each carrying a 'text' string and a 'reversed' bool.
        app: FastAPI app whose state holds the precomputed 'item_centroids'.

    Returns:
        1-D array of cosine similarities, one per stored item centroid.
    """
    texts = [query['text'] for query in queries]
    reversed_flags = [query['reversed'] for query in queries]

    with context_logger(f"Sending encoding requests for {len(texts)} queries"):
        embeddings = await encode(texts=texts, mode="item")

    with context_logger(f"Aligning item embeddings based on keying"):
        keying = ["negative" if flag else "positive" for flag in reversed_flags]
        alignment = align_embeddings(embeddings, keying)
        query_centroid = alignment['item_centroid_aligned']
        # Alignment signals failure with NaN; fall back to a plain mean.
        if np.any(np.isnan(query_centroid)):
            logger.info(f"Query embedding alignment failed, calculating centroid without alignment")
            query_centroid = embeddings.mean(axis=0)

    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['item_centroids'],
            Y=query_centroid.reshape(1, -1)
        ).ravel()
    return similarities
async def semantic_scale_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Embed the query scale(s) and score every stored scale centroid.

    Args:
        queries: Dicts each carrying a 'text' string to embed in scale mode.
        app: FastAPI app whose state holds the precomputed 'scale_centroids'.

    Returns:
        1-D array of cosine similarities, one per stored scale centroid.
    """
    query = [q['text'] for q in queries]
    with context_logger(f"Sending encoding requests for {len(query)} queries."):
        query_embeddings = await encode(texts=query, mode="scale")
        # BUGFIX: the previous squeeze()+reshape(1, -1) flattened multiple
        # query embeddings into a single (1, k*d) row, which broke
        # cosine_similarity for more than one query. Average across queries
        # into one d-dimensional centroid instead (a no-op for a single
        # query), mirroring semantic_item_search. atleast_2d keeps an
        # already-flat single embedding intact.
        query_centroid = np.atleast_2d(query_embeddings).mean(axis=0)
    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['scale_centroids'],
            Y=query_centroid.reshape(1, -1)
        ).ravel()
    return similarities
async def compute_search_results(similarities: np.ndarray, app: FastAPI) -> pl.DataFrame:
    """Attach similarity scores to the item metadata and aggregate to one
    row per publication (meta_doi).

    Args:
        similarities: One similarity score per row of app.state.data['meta'],
            in the same order.
        app: FastAPI app whose state holds the 'meta' metadata DataFrame.

    Returns:
        One row per meta_doi with per-scale lists ('scale_name',
        'is_instrument', 'similarity'), collected 'warning_codes',
        'warning_count', 'max_similarity' and 'max_abs_similarity'.
    """
    search_results = (
        app.state.data['meta'].clone()
        .with_columns(
            # One score per metadata row, rounded for display/stability.
            pl.Series("similarity", similarities).round(3)
        )
        .group_by("meta_doi")
        .agg([
            pl.col('scale_name'),
            pl.col('is_instrument'),
            # These fields are constant within a publication; keep one value.
            pl.col([
                'meta_instrument_name',
                'warn_item_count_deviation',
                'warn_scale_count_deviation',
                'warn_item_text_deviation',
                'warn_keying_correction'
            ]).first(),
            pl.col('similarity'),
        ])
        .with_columns(
            # Collect the active warning flags into a list of string codes;
            # unmatched when-clauses yield nulls, which are dropped.
            pl.concat_list([
                pl.when(pl.col('warn_item_count_deviation'))
                .then(pl.lit('ITEM_COUNT_DEVIATION')),
                pl.when(pl.col('warn_scale_count_deviation'))
                .then(pl.lit('SCALE_COUNT_DEVIATION')),
                pl.when(pl.col('warn_item_text_deviation'))
                .then(pl.lit('ITEM_TEXT_DEVIATION')),
                pl.when(pl.col('warn_keying_correction'))
                .then(pl.lit('KEYING_CORRECTION')),
            ]).list.drop_nulls().alias('warning_codes')
        )
        .with_columns(
            pl.col('warning_codes').list.len().alias('warning_count'),
            max_similarity = pl.col("similarity").list.max(),
            # BUGFIX: take the max of the absolute values, not the absolute
            # value of the max — they differ whenever the list contains
            # negative similarities (cosine similarity can be negative).
            max_abs_similarity = pl.col("similarity").list.eval(pl.element().abs()).list.max(),
        )
        # Raw warning flags are superseded by 'warning_codes'.
        .drop([
            'warn_item_count_deviation',
            'warn_scale_count_deviation',
            'warn_item_text_deviation',
            'warn_keying_correction'
        ])
    )
    return search_results
async def filter_search(df: pl.DataFrame, filter_string: str) -> pl.DataFrame:
    """Case-insensitively filter search results by instrument or scale name.

    Args:
        df: Aggregated search results with a 'meta_instrument_name' string
            column and a 'scale_name' list-of-strings column.
        filter_string: Substring to look for; falsy values disable filtering.

    Returns:
        Rows whose instrument name or any scale name contains the substring.
    """
    if not filter_string:
        return df
    # BUGFIX: the columns are lowercased before matching but the needle was
    # not, so any filter containing an uppercase letter could never match.
    needle = filter_string.lower()
    # literal=True treats user input as a plain substring rather than a
    # regex; an invalid regex pattern would otherwise raise at query time.
    in_instrument_name = (
        df['meta_instrument_name']
        .str.to_lowercase()
        .str.contains(needle, literal=True)
    )
    in_scale_names = (
        df['scale_name']
        .list.drop_nulls()  # Remove null values from each list
        .list.join(" ")     # Join list elements with space separator
        .str.to_lowercase()
        .str.contains(needle, literal=True)
    )
    return df.filter(in_instrument_name | in_scale_names)
async def refine_search(
    df: pl.DataFrame,
    sort_col: str,
    sort_descending: bool,
    page_index: int,
    page_size: int
) -> pl.DataFrame:
    """Sort the search results by one column and return a single page.

    Args:
        df: Search results to sort and paginate.
        sort_col: Name of the column to sort by.
        sort_descending: Whether to sort in descending order.
        page_index: Zero-based page number.
        page_size: Number of rows per page.

    Returns:
        The requested page (may be shorter than page_size on the last page,
        or empty past the end).
    """
    offset = page_index * page_size
    ordered = df.sort(by=sort_col, descending=sort_descending)
    return ordered[offset:offset + page_size]