|
|
import io |
|
|
import csv |
|
|
import logging |
|
|
import polars as pl |
|
|
from fastapi.responses import StreamingResponse |
|
|
from typing import Annotated |
|
|
from fastapi import APIRouter, HTTPException, Request, Form |
|
|
from fastapi.templating import Jinja2Templates |
|
|
|
|
|
from src.config import config |
|
|
from src.schemas.search_request import SearchRequest |
|
|
from src.services.usage_tracking_service import send_tracking_event_to_db |
|
|
from src.services.search_service import semantic_item_search, semantic_scale_search, compute_search_results, refine_search, filter_search |
|
|
from src.utils.logging import context_logger |
|
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)





# Router holding the search endpoints of this module; presumably included into
# the FastAPI app elsewhere (not visible here).
router = APIRouter()





# Jinja2 template environment used to render the search-results partial.
# NOTE(review): the directory is a relative path, so it resolves against the
# process working directory — assumes the app is launched from the project root.
templates = Jinja2Templates(directory="./src/templates")
|
|
|
|
|
async def _run_filtered_search(
    request: Request,
    search_data: SearchRequest,
) -> pl.DataFrame:
    """Run the semantic search described by ``search_data`` and apply its filter.

    Shared by the ``/`` and ``/export`` endpoints so both work on exactly the
    same filtered result set before they diverge (pagination vs. full CSV dump).

    Args:
        request: Incoming request; its ``app`` is passed through to the search
            service functions.
        search_data: Parsed search form.

    Returns:
        The frame returned by ``filter_search``, before the caller applies any
        sorting or pagination.

    Raises:
        HTTPException: 400 if ``search_data.mode`` is neither ``'items'`` nor
            ``'scale'``.
    """
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app,
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app,
    )

    return await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )


@router.post("/")
async def search(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Run a search and render one page of results as an HTML partial.

    If usage tracking is enabled in config, a tracking event is sent before the
    search runs. Results are filtered, then sorted and paginated by
    ``refine_search``. They are rendered with
    ``pages/search/partials/_search_results.jinja``.

    Args:
        request: Incoming request (also required by the template response).
        search_data: Search form fields (mode, queries, filter, sort, paging).

    Returns:
        A ``TemplateResponse`` for the search-results partial.

    Raises:
        HTTPException: 400 for an unknown search mode.
    """
    # %-style args hand message interpolation to the logging framework.
    logger.info("Processing search request: \n%s", search_data.model_dump_json())

    # Tracking happens before the mode is validated, so invalid-mode requests
    # are still recorded (same order as before the refactor).
    if config.tracking_db.enabled:
        with context_logger("Sending search usage data to tracking DB"):
            await send_tracking_event_to_db(search_data=search_data)
    else:
        logger.warning("Tracking disabled in config / ENV, skipping DB write!")

    filtered_results = await _run_filtered_search(request, search_data)

    search_results = await refine_search(
        df=filtered_results,
        sort_col=search_data.sort_col,
        sort_descending=search_data.sort_descending,
        page_index=search_data.page_index,
        page_size=search_data.page_size,
    )

    return templates.TemplateResponse(
        name="pages/search/partials/_search_results.jinja",
        context={
            "request": request,
            "search_mode": search_data.mode,
            "search_result": search_results.to_dicts(),
            "filter_string": search_data.filter_string,
            "sort_col": search_data.sort_col,
            "sort_descending": search_data.sort_descending,
            "page_index": search_data.page_index,
            "page_size": search_data.page_size,
            # Row count of the filtered frame (before pagination by refine_search).
            "results_count": filtered_results.height,
        },
    )


@router.post("/export")
async def export_search_results(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Run a search and return all filtered results as a CSV attachment.

    Unlike ``search``, this endpoint does not paginate and does not send a
    tracking event. Every filtered row is sorted by ``search_data.sort_col``
    and written out.

    Args:
        request: Incoming request; its ``app`` is used by the search services.
        search_data: Search form fields (mode, queries, filter, sort).

    Returns:
        A ``StreamingResponse`` with ``text/csv`` content, served as
        ``synthnet_results.csv``.

    Raises:
        HTTPException: 400 for an unknown search mode.
    """
    filtered_results = await _run_filtered_search(request, search_data)

    sorted_results = filtered_results.sort(
        by=search_data.sort_col,
        descending=search_data.sort_descending,
    )

    # Flatten the list-valued columns into "; "-joined strings so each result
    # fits in a single CSV row; per-scale similarities are rounded to 3 places.
    export_df = sorted_results.select([
        "meta_doi",
        "meta_instrument_name",
        "max_similarity",
        pl.col("scale_name").list.join("; ").alias("scale_names"),
        pl.col("similarity").list.eval(pl.element().round(3).cast(pl.Utf8)).list.join("; ").alias("similarities"),
        pl.col("warning_codes").list.join("; ").alias("warnings"),
    ])

    # NOTE(review): cell values are written verbatim; if instrument/scale names
    # can start with '=', '+', '-' or '@', spreadsheet apps may treat them as
    # formulas (CSV injection). Confirm the data source is trusted.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(export_df.columns)
    writer.writerows(export_df.iter_rows())

    # getvalue() returns the whole buffer regardless of stream position, so no
    # seek is needed before handing it to the response.
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=synthnet_results.csv"},
    )