File size: 4,413 Bytes
6ca4b94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import io
import csv
import logging
import polars as pl
from fastapi.responses import StreamingResponse
from typing import Annotated
from fastapi import APIRouter, HTTPException, Request, Form
from fastapi.templating import Jinja2Templates
from src.config import config
from src.schemas.search_request import SearchRequest
from src.services.usage_tracking_service import send_tracking_event_to_db
from src.services.search_service import semantic_item_search, semantic_scale_search, compute_search_results, refine_search, filter_search
from src.utils.logging import context_logger
# Logger named after this module.
logger = logging.getLogger(__name__)

# Templates are resolved from a path relative to the process working
# directory, so the app is presumably started from the repository root.
templates = Jinja2Templates(directory="./src/templates")

# Router holding the search endpoints registered in this module.
router = APIRouter()
@router.post("/")
async def search(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Run a semantic search and render one page of results as an HTML partial.

    Pipeline:
      1. If tracking is enabled in config, record the request in the tracking DB.
      2. Compute similarities against the item or the scale index, chosen by
         ``search_data.mode``.
      3. Turn the similarities into a results frame and apply
         ``search_data.filter_string``.
      4. Pass the filtered frame to ``refine_search`` with the sort and
         paging parameters.

    Args:
        request: Incoming request. ``request.app`` is forwarded to the search
            services.
        search_data: Form-encoded search parameters.

    Returns:
        A ``TemplateResponse`` rendering
        ``pages/search/partials/_search_results.jinja``. ``results_count`` in
        the context is the number of rows after filtering and before
        ``refine_search``.

    Raises:
        HTTPException: 400 if ``search_data.mode`` is neither ``'items'`` nor
            ``'scale'``.
    """
    # %-style args let the logging framework do the interpolation.
    logger.info("Processing search request: \n%s", search_data.model_dump_json())

    if config.tracking_db.enabled:
        with context_logger("Sending search usage data to tracking DB"):
            await send_tracking_event_to_db(search_data=search_data)
    else:
        logger.warning("Tracking disabled in config / ENV, skipping DB write!")

    # Pick the index to search based on the requested mode.
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app,
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app,
    )
    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )
    search_results = await refine_search(
        df=filtered_results,
        sort_col=search_data.sort_col,
        sort_descending=search_data.sort_descending,
        page_index=search_data.page_index,
        page_size=search_data.page_size,
    )

    return templates.TemplateResponse(
        name="pages/search/partials/_search_results.jinja",
        context={
            "request": request,
            "search_mode": search_data.mode,
            "search_result": search_results.to_dicts(),
            "filter_string": search_data.filter_string,
            "sort_col": search_data.sort_col,
            "sort_descending": search_data.sort_descending,
            "page_index": search_data.page_index,
            "page_size": search_data.page_size,
            # Count after filtering and before refine_search, so the template
            # can show the total number of matches.
            "results_count": filtered_results.height,
        },
    )
def _to_export_csv(df: pl.DataFrame) -> str:
    """Flatten search results to the export columns and serialise them as CSV text.

    List-valued columns (``scale_name``, ``similarity``, ``warning_codes``) are
    joined with ``"; "`` so each result fits on one row. Per-scale similarities
    are rounded to 3 decimals before joining. Rows use the ``csv`` module's
    default dialect.
    """
    export_df = df.select([
        "meta_doi",
        "meta_instrument_name",
        "max_similarity",
        pl.col("scale_name").list.join("; ").alias("scale_names"),
        pl.col("similarity").list.eval(pl.element().round(3).cast(pl.Utf8)).list.join("; ").alias("similarities"),
        pl.col("warning_codes").list.join("; ").alias("warnings"),
    ])
    # NOTE(review): cell values are written verbatim. If exports are opened in
    # spreadsheet software, text starting with "=", "+", "-" or "@" may be
    # treated as a formula (CSV injection). Consider escaping if the
    # underlying metadata is not fully trusted.
    with io.StringIO() as output:
        writer = csv.writer(output)
        writer.writerow(export_df.columns)
        writer.writerows(export_df.iter_rows())
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek is needed before reading it.
        return output.getvalue()


@router.post("/export")
async def export_search_results(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Run the search and return every matching result as a CSV attachment.

    Filtering and sorting follow ``search_data``. No pagination is applied,
    so all filtered rows are exported.

    Args:
        request: Incoming request. ``request.app`` is forwarded to the search
            services.
        search_data: Form-encoded search parameters.

    Returns:
        A ``StreamingResponse`` with ``text/csv`` content and a
        ``synthnet_results.csv`` attachment filename.

    Raises:
        HTTPException: 400 if ``search_data.mode`` is neither ``'items'`` nor
            ``'scale'``.
    """
    logger.info("Processing export request: \n%s", search_data.model_dump_json())

    # Pick the index to search based on the requested mode.
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app,
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app,
    )
    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )
    sorted_results = filtered_results.sort(
        by=search_data.sort_col,
        descending=search_data.sort_descending,
    )

    csv_text = _to_export_csv(sorted_results)
    return StreamingResponse(
        iter([csv_text]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=synthnet_results.csv"},
    )