import logging
from typing import Literal

import httpx
import numpy as np
import polars as pl
from fastapi import FastAPI
from sentence_transformers import util
from sklearn.metrics.pairwise import cosine_similarity

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)


async def encode(texts: list[str], mode: Literal["item", "scale"] = "item") -> np.ndarray:
    """Request embeddings for `texts` from the local encoding service."""
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8001/encode",
            json={"texts": texts, "mode": mode},
            timeout=30.0,
        )
        response.raise_for_status()

    result = np.array(response.json()['embeddings'])
    return result


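# Illustrative sketch (not used by the app): exercising encode() from a
# script. The response shape {"embeddings": [[...], ...]} is inferred from
# the parsing above rather than a documented contract, and the encoding
# service is assumed to be running on localhost:8001.
async def _encode_example() -> None:
    embeddings = await encode(
        ["I am the life of the party.", "I prefer quiet evenings at home."],
        mode="item",
    )
    print(embeddings.shape)  # expected: (2, embedding_dim)

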
def align_embeddings(item_embeddings: np.ndarray, keying: list[str]) -> dict:
    """Reflect negatively keyed item embeddings onto the positive side.

    Returns the positive and negative centroids, the aligned embeddings,
    and their centroid; every value is np.nan when alignment fails.
    """
    alignment_failed = {
        'item_centroid_positive': np.nan,
        'item_centroid_negative': np.nan,
        'item_embeddings_aligned': np.nan,
        'item_centroid_aligned': np.nan
    }

    is_positive = np.array([x == "positive" for x in keying])
    is_negative = np.array([x == "negative" for x in keying])
    item_embeddings_positive = item_embeddings[is_positive]
    item_embeddings_negative = item_embeddings[is_negative]

    if item_embeddings_positive.size == 0 or item_embeddings_negative.size == 0:
        return alignment_failed

    item_centroid_positive = item_embeddings_positive.mean(axis=0)
    item_centroid_negative = item_embeddings_negative.mean(axis=0)

    # Items whose embedding already points away from the positive centroid.
    cosine_similarities = util.cos_sim(item_embeddings, item_centroid_positive).numpy().squeeze()
    synthetic_is_negative = cosine_similarities < 0

    # Axis running from the negative centroid to the positive centroid.
    polarity_axis = item_centroid_positive - item_centroid_negative
    axis_magnitude = np.sqrt(np.sum(polarity_axis**2))

    if not np.isfinite(axis_magnitude) or axis_magnitude <= 0 or not synthetic_is_negative.any():
        return alignment_failed

    polarity_unit_vector = polarity_axis / axis_magnitude
    reflection_plane_center = (item_centroid_positive + item_centroid_negative) / 2

    # Signed distance of each embedding to the hyperplane halfway between
    # the centroids, measured along the polarity axis.
    signed_distances_to_plane = np.dot(
        item_embeddings - reflection_plane_center,
        polarity_unit_vector
    )

    # Reflect only items that are negatively keyed AND semantically negative.
    items_to_align = is_negative & synthetic_is_negative
    reflection_distances = np.where(items_to_align, signed_distances_to_plane, 0)

    # Householder-style reflection across the hyperplane: x' = x - 2 * d * u.
    item_embeddings_aligned = item_embeddings - 2 * np.outer(
        reflection_distances,
        polarity_unit_vector
    )

    item_centroid_aligned = item_embeddings_aligned.mean(axis=0)

    return {
        'item_centroid_positive': item_centroid_positive,
        'item_centroid_negative': item_centroid_negative,
        'item_embeddings_aligned': item_embeddings_aligned,
        'item_centroid_aligned': item_centroid_aligned
    }


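# Illustrative sketch (not called anywhere): align_embeddings() on a made-up
# 2-D example. Real embeddings live in the encoder's full embedding space;
# the toy vectors here only demonstrate the geometry.
def _align_embeddings_example() -> None:
    toy = np.array([
        [1.0, 0.2],    # positively keyed item
        [0.9, -0.1],   # positively keyed item
        [-1.0, 0.1],   # negatively keyed item, points away from the positive centroid
    ])
    aligned = align_embeddings(toy, ["positive", "positive", "negative"])
    # The negative item is reflected across the hyperplane between the two
    # centroids, so the aligned centroid is no longer pulled toward zero.
    print(aligned['item_centroid_aligned'])

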
async def semantic_item_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Embed the query items, align them by keying, and score the aligned
    centroid against the precomputed item centroids."""
    query_items = [q['text'] for q in queries]
    query_keys = [q['reversed'] for q in queries]

    with context_logger(f"Sending encoding requests for {len(query_items)} queries"):
        query_embeddings = await encode(texts=query_items, mode="item")

    with context_logger("Aligning item embeddings based on keying"):
        keying = ["negative" if x else "positive" for x in query_keys]
        query_embeddings_aligned = align_embeddings(query_embeddings, keying)
        query_centroid = query_embeddings_aligned['item_centroid_aligned']

        # Fall back to the unaligned centroid when alignment is not possible
        # (e.g. all items share one keying direction).
        if np.any(np.isnan(query_centroid)):
            logger.info("Query embedding alignment failed, calculating centroid without alignment")
            query_centroid = query_embeddings.mean(axis=0)

    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['item_centroids'],
            Y=query_centroid.reshape(1, -1)
        ).ravel()

    return similarities


async def semantic_scale_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Embed the query as a scale and score it against the scale centroids."""
    query = [q['text'] for q in queries]

    with context_logger(f"Sending encoding requests for {len(query)} queries"):
        query_embeddings = await encode(texts=query, mode="scale")
        # The scale encoder is assumed to return a single pooled embedding,
        # so squeezing leaves one vector to compare against the centroids.
        query_embeddings = query_embeddings.squeeze()

    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['scale_centroids'],
            Y=query_embeddings.reshape(1, -1)
        ).ravel()

    return similarities


async def compute_search_results(similarities: np.ndarray, app: FastAPI) -> pl.DataFrame:
    """Join similarities onto the metadata and aggregate one row per publication."""
    search_results = (
        app.state.data['meta'].clone()
        .with_columns(
            pl.Series("similarity", similarities).round(3)
        )
        # Collect each publication's scales and similarities into lists;
        # publication-level columns are constant per group, so take the first.
        .group_by("meta_doi")
        .agg([
            pl.col('scale_name'),
            pl.col('is_instrument'),
            pl.col([
                'meta_instrument_name',
                'warn_item_count_deviation',
                'warn_scale_count_deviation',
                'warn_item_text_deviation',
                'warn_keying_correction'
            ]).first(),
            pl.col('similarity'),
        ])
        # Collapse the boolean warning flags into a list of warning codes;
        # false flags produce nulls, which are dropped from the list.
        .with_columns(
            pl.concat_list([
                pl.when(pl.col('warn_item_count_deviation'))
                .then(pl.lit('ITEM_COUNT_DEVIATION')),
                pl.when(pl.col('warn_scale_count_deviation'))
                .then(pl.lit('SCALE_COUNT_DEVIATION')),
                pl.when(pl.col('warn_item_text_deviation'))
                .then(pl.lit('ITEM_TEXT_DEVIATION')),
                pl.when(pl.col('warn_keying_correction'))
                .then(pl.lit('KEYING_CORRECTION')),
            ]).list.drop_nulls().alias('warning_codes')
        )
        .with_columns(
            pl.col('warning_codes').list.len().alias('warning_count'),
            max_similarity=pl.col("similarity").list.max(),
            # Largest similarity by magnitude, not the magnitude of the largest.
            max_abs_similarity=pl.col("similarity").list.eval(pl.element().abs()).list.max(),
        )
        .drop([
            'warn_item_count_deviation',
            'warn_scale_count_deviation',
            'warn_item_text_deviation',
            'warn_keying_correction'
        ])
    )

    return search_results


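# Illustrative sketch (not used by the app): the warning-flag collapse used
# above, on a toy frame. Each true flag contributes its code; nulls from
# false flags are dropped, leaving an empty list when nothing fired.
def _warning_codes_example() -> None:
    toy = pl.DataFrame({
        "warn_item_count_deviation": [True, False],
        "warn_keying_correction": [False, False],
    })
    out = toy.with_columns(
        pl.concat_list([
            pl.when(pl.col("warn_item_count_deviation")).then(pl.lit("ITEM_COUNT_DEVIATION")),
            pl.when(pl.col("warn_keying_correction")).then(pl.lit("KEYING_CORRECTION")),
        ]).list.drop_nulls().alias("warning_codes")
    )
    print(out)  # row 1: ["ITEM_COUNT_DEVIATION"]; row 2: []

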
async def filter_search(df: pl.DataFrame, filter_string: str) -> pl.DataFrame:
    """Keep rows whose instrument name or scale names contain `filter_string`."""
    if filter_string:
        # Lowercase both sides and match literally so regex metacharacters in
        # user input cannot break or widen the search.
        filter_string = filter_string.lower()
        in_instrument_name = (
            df['meta_instrument_name']
            .str.to_lowercase()
            .str.contains(filter_string, literal=True)
        )
        in_scale_names = (
            df['scale_name']
            .list.drop_nulls()
            .list.join(" ")
            .str.to_lowercase()
            .str.contains(filter_string, literal=True)
        )
        return df.filter(in_instrument_name | in_scale_names)
    return df


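# Illustrative sketch (not used by the app): filter_search() on a toy frame
# whose columns mirror the ones used above; the instrument and scale names
# are invented for the example.
async def _filter_search_example() -> None:
    toy = pl.DataFrame({
        "meta_instrument_name": ["Big Five Inventory", "PANAS"],
        "scale_name": [["Extraversion"], ["Positive Affect"]],
    })
    hits = await filter_search(toy, "extraversion")
    print(hits["meta_instrument_name"].to_list())  # ["Big Five Inventory"]

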
async def refine_search(
    df: pl.DataFrame,
    sort_col: str,
    sort_descending: bool,
    page_index: int,
    page_size: int
) -> pl.DataFrame:
    """Sort the results by `sort_col` and return the requested page."""
    sorted_result = df.sort(by=sort_col, descending=sort_descending)

    start_index = page_index * page_size
    end_index = start_index + page_size
    page_results = sorted_result[start_index:end_index]

    return page_results


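# Illustrative sketch (not used by the app): refine_search() pagination on a
# toy frame, sorted by a hypothetical "similarity" column.
async def _refine_search_example() -> None:
    toy = pl.DataFrame({"similarity": [0.2, 0.9, 0.5, 0.7]})
    page = await refine_search(
        toy, sort_col="similarity", sort_descending=True, page_index=0, page_size=2
    )
    print(page)  # the two most similar rows: 0.9 and 0.7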