"""Semantic search over item/scale embeddings served by a local encoder service.

Items are embedded via an HTTP encoder sidecar, optionally polarity-aligned
(reversed items reflected across the positive/negative midplane), compared to
precomputed centroids, and the results aggregated into a polars DataFrame.
"""

import logging
from typing import Literal, Optional

import httpx
import numpy as np
import polars as pl
from fastapi.applications import FastAPI
from sentence_transformers import util
from sklearn.metrics.pairwise import cosine_similarity

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)


async def encode(texts: list[str], mode: Literal["item", "scale"] = "item") -> np.ndarray:
    """Request embeddings for ``texts`` from the local encoder service.

    Args:
        texts: Strings to embed.
        mode: Encoder mode forwarded to the service.

    Returns:
        Array of embeddings built from the service's ``embeddings`` payload,
        one row per input text.

    Raises:
        httpx.HTTPStatusError: If the service responds with an error status.
        httpx.TimeoutException: If the request exceeds the 30 s timeout.
    """
    # NOTE(review): endpoint is hard-coded although `src.config.config` is
    # imported and unused — consider sourcing the URL from config.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8001/encode",
            json={"texts": texts, "mode": mode},
            timeout=30.0,
        )
        response.raise_for_status()
    return np.array(response.json()['embeddings'])


def align_embeddings(item_embeddings: np.ndarray, keying: list[str]) -> dict:
    """Reflect negatively-keyed item embeddings across the polarity midplane.

    Builds a polarity axis from the positive and negative item centroids and
    reflects every item that is both declared ``"negative"`` in ``keying`` and
    measured negative (cosine similarity to the positive centroid below zero)
    across the hyperplane halfway between the two centroids.

    Args:
        item_embeddings: 2-D array, one embedding row per item.
        keying: Per-item polarity labels, ``"positive"`` or ``"negative"``,
            same length as ``item_embeddings``.

    Returns:
        Dict with keys ``item_centroid_positive``, ``item_centroid_negative``,
        ``item_embeddings_aligned`` and ``item_centroid_aligned``. All values
        are ``np.nan`` when alignment is not possible (a polarity class is
        empty, the axis is degenerate, or no item needs reflection).
    """
    positive_mask = np.array([label == "positive" for label in keying])
    negative_mask = np.array([label == "negative" for label in keying])
    item_embeddings_positive = item_embeddings[positive_mask]
    item_embeddings_negative = item_embeddings[negative_mask]

    # Sentinel returned on every failure path; callers detect it via isnan.
    failure = {
        'item_centroid_positive': np.nan,
        'item_centroid_negative': np.nan,
        'item_embeddings_aligned': np.nan,
        'item_centroid_aligned': np.nan,
    }

    # Alignment is undefined without at least one item of each polarity.
    if item_embeddings_positive.size == 0 or item_embeddings_negative.size == 0:
        return failure

    item_centroid_positive = item_embeddings_positive.mean(axis=0)
    item_centroid_negative = item_embeddings_negative.mean(axis=0)

    # Items pointing away from the positive centroid are treated as
    # semantically negative regardless of their declared keying.
    cosine_similarities = util.cos_sim(item_embeddings, item_centroid_positive).numpy().squeeze()
    synthetic_is_negative = cosine_similarities < 0

    polarity_axis = item_centroid_positive - item_centroid_negative
    axis_magnitude = np.sqrt(np.sum(polarity_axis ** 2))
    # Degenerate axis (coincident/non-finite centroids) or nothing to reflect.
    if not np.isfinite(axis_magnitude) or axis_magnitude <= 0 or not any(synthetic_is_negative):
        return failure

    polarity_unit_vector = polarity_axis / axis_magnitude
    reflection_plane_center = (item_centroid_positive + item_centroid_negative) / 2
    signed_distances_to_plane = np.dot(
        item_embeddings - reflection_plane_center, polarity_unit_vector
    )
    # Reflect only items that are BOTH declared negative and measured negative;
    # all other rows get a zero reflection distance (no-op).
    items_to_align = negative_mask & synthetic_is_negative
    reflection_distances = np.where(items_to_align, signed_distances_to_plane, 0)
    item_embeddings_aligned = item_embeddings - 2 * np.outer(
        reflection_distances, polarity_unit_vector
    )
    item_centroid_aligned = item_embeddings_aligned.mean(axis=0)

    return {
        'item_centroid_positive': item_centroid_positive,
        'item_centroid_negative': item_centroid_negative,
        'item_embeddings_aligned': item_embeddings_aligned,
        'item_centroid_aligned': item_centroid_aligned,
    }


async def semantic_item_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Score stored item centroids against the aligned centroid of the queries.

    Args:
        queries: Dicts with ``text`` (item wording) and ``reversed`` (truthy
            when the item is negatively keyed).
        app: FastAPI app whose ``state.data['item_centroids']`` holds the
            corpus centroids to score.

    Returns:
        1-D array of cosine similarities, one per stored item centroid.
    """
    query_items = [q['text'] for q in queries]
    query_keys = [q['reversed'] for q in queries]

    with context_logger(f"Sending encoding requests for {len(query_items)} queries"):
        query_embeddings = await encode(texts=query_items, mode="item")

    with context_logger(f"Aligning item embeddings based on keying"):
        keying = ["negative" if x else "positive" for x in query_keys]
        query_embeddings_aligned = align_embeddings(query_embeddings, keying)
        query_centroid = query_embeddings_aligned['item_centroid_aligned']
        # Alignment signals failure with NaN; fall back to the raw centroid.
        if np.any(np.isnan(query_centroid)):
            logger.info(f"Query embedding alignment failed, calculating centroid without alignment")
            query_centroid = query_embeddings.mean(axis=0)

    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['item_centroids'],
            Y=query_centroid.reshape(1, -1),
        ).ravel()

    return similarities


async def semantic_scale_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    """Score stored scale centroids against the encoded query texts.

    Args:
        queries: Dicts with a ``text`` key holding the scale description.
        app: FastAPI app whose ``state.data['scale_centroids']`` holds the
            corpus centroids to score.

    Returns:
        1-D array of cosine similarities, one per stored scale centroid.
    """
    query = [q['text'] for q in queries]

    with context_logger(f"Sending encoding requests for {len(query)} queries."):
        query_embeddings = await encode(texts=query, mode="scale")
        query_embeddings = query_embeddings.squeeze()

    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['scale_centroids'],
            Y=query_embeddings.reshape(1, -1),
        ).ravel()

    return similarities


async def compute_search_results(similarities: np.ndarray, app: FastAPI) -> pl.DataFrame:
    """Aggregate per-row similarities into one result row per publication DOI.

    Attaches ``similarities`` to the metadata table, groups by ``meta_doi``,
    collects warning codes from the boolean warning columns, and computes
    per-group similarity summary statistics.

    Args:
        similarities: One similarity per row of ``app.state.data['meta']``,
            in the same order.
        app: FastAPI app whose ``state.data['meta']`` holds the metadata.

    Returns:
        Grouped DataFrame with list columns (``scale_name``, ``similarity``,
        ``warning_codes``), warning counts, and similarity maxima.
    """
    search_results = (
        app.state.data['meta'].clone()
        .with_columns(
            pl.Series("similarity", similarities).round(3)
        )
        .group_by("meta_doi")
        .agg([
            pl.col('scale_name'),
            pl.col('is_instrument'),
            pl.col([
                'meta_instrument_name',
                'warn_item_count_deviation',
                'warn_scale_count_deviation',
                'warn_item_text_deviation',
                'warn_keying_correction'
            ]).first(),
            pl.col('similarity'),
        ])
        .with_columns(
            # Each warning flag maps to a code; unset flags yield nulls which
            # are dropped, leaving only the active warning codes.
            pl.concat_list([
                pl.when(pl.col('warn_item_count_deviation'))
                .then(pl.lit('ITEM_COUNT_DEVIATION')),
                pl.when(pl.col('warn_scale_count_deviation'))
                .then(pl.lit('SCALE_COUNT_DEVIATION')),
                pl.when(pl.col('warn_item_text_deviation'))
                .then(pl.lit('ITEM_TEXT_DEVIATION')),
                pl.when(pl.col('warn_keying_correction'))
                .then(pl.lit('KEYING_CORRECTION')),
            ]).list.drop_nulls().alias('warning_codes')
        )
        .with_columns(
            pl.col('warning_codes').list.len().alias('warning_count'),
            max_similarity=pl.col("similarity").list.max(),
            # BUG FIX: was `list.max().abs()`, i.e. |max(sim)| — wrong when a
            # group's similarities are all negative. Take abs per element
            # first, then the maximum.
            max_abs_similarity=pl.col("similarity").list.eval(pl.element().abs()).list.max(),
        )
        .drop([
            'warn_item_count_deviation',
            'warn_scale_count_deviation',
            'warn_item_text_deviation',
            'warn_keying_correction'
        ])
    )
    return search_results


async def filter_search(df: pl.DataFrame, filter_string: str) -> pl.DataFrame:
    """Case-insensitively filter results by instrument or scale name.

    Args:
        df: Result frame with ``meta_instrument_name`` (str) and
            ``scale_name`` (list of str) columns.
        filter_string: Substring to look for; empty/falsy returns ``df``
            unchanged.

    Returns:
        Rows whose instrument name or any scale name contains the substring.
    """
    if filter_string:
        # BUG FIX: the columns were lowercased but the needle was not, so any
        # uppercase query could never match — lowercase the needle too.
        # `literal=True` stops user input being parsed as a regex (metachars
        # like '(' or '[' would otherwise raise or mismatch).
        needle = filter_string.lower()
        in_instrument_name = (
            df['meta_instrument_name']
            .str.to_lowercase()
            .str.contains(needle, literal=True)
        )
        in_scale_names = (
            df['scale_name']
            .list.drop_nulls()      # Remove null values from each list
            .list.join(" ")         # Join list elements with space separator
            .str.to_lowercase()
            .str.contains(needle, literal=True)
        )
        return df.filter(in_instrument_name | in_scale_names)
    return df


async def refine_search(
    df: pl.DataFrame,
    sort_col: str,
    sort_descending: bool,
    page_index: int,
    page_size: int
) -> pl.DataFrame:
    """Sort the result frame and slice out one page.

    Args:
        df: Frame to paginate.
        sort_col: Column to sort by.
        sort_descending: Sort direction.
        page_index: Zero-based page number.
        page_size: Rows per page.

    Returns:
        The requested page (may be shorter than ``page_size`` on the last
        page, or empty past the end — slicing never raises).
    """
    sorted_result = df.sort(by=sort_col, descending=sort_descending)
    start_index = page_index * page_size
    end_index = start_index + page_size
    return sorted_result[start_index:end_index]