# NOTE(review): the three lines below are VCS/CI paste residue (committer name,
# commit message, short SHA), not Python — commented out so the module parses.
# github-actions
# Sync from GitHub (CI)
# 6ca4b94
import io
import csv
import logging
import polars as pl
from fastapi.responses import StreamingResponse
from typing import Annotated
from fastapi import APIRouter, HTTPException, Request, Form
from fastapi.templating import Jinja2Templates
from src.config import config
from src.schemas.search_request import SearchRequest
from src.services.usage_tracking_service import send_tracking_event_to_db
from src.services.search_service import semantic_item_search, semantic_scale_search, compute_search_results, refine_search, filter_search
from src.utils.logging import context_logger
# Module-level logger, named after the module per stdlib convention.
logger = logging.getLogger(__name__)
# Router instance the endpoints below attach to; mounted by the application.
router = APIRouter()
# Jinja2 template root (relative path, resolved against the process CWD).
templates = Jinja2Templates(directory="./src/templates")
@router.post("/")
async def search(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Run a semantic search and render the paginated results partial.

    Optionally records a usage-tracking event first (when enabled in config),
    then runs the shared mode-dispatch / similarity / filter pipeline and
    applies sorting + pagination for the page being displayed.

    Raises:
        HTTPException(400): if ``search_data.mode`` is neither 'items' nor 'scale'.
    """
    # Lazy %-args: the JSON dump is only rendered if INFO logging is enabled.
    logger.info("Processing search request: \n%s", search_data.model_dump_json())

    if config.tracking_db.enabled:
        with context_logger("Sending search usage data to tracking DB"):
            await send_tracking_event_to_db(
                search_data=search_data
            )
    else:
        logger.warning("Tracking disabled in config / ENV, skipping DB write!")

    filtered_results = await _gather_filtered_results(search_data, request.app)

    # Sort + paginate only the slice that will be displayed.
    search_results = await refine_search(
        df=filtered_results,
        sort_col=search_data.sort_col,
        sort_descending=search_data.sort_descending,
        page_index=search_data.page_index,
        page_size=search_data.page_size,
    )

    return templates.TemplateResponse(
        name="pages/search/partials/_search_results.jinja",
        context={
            "request": request,
            "search_mode": search_data.mode,
            "search_result": search_results.to_dicts(),
            "filter_string": search_data.filter_string,
            "sort_col": search_data.sort_col,
            "sort_descending": search_data.sort_descending,
            "page_index": search_data.page_index,
            "page_size": search_data.page_size,
            # Total match count BEFORE pagination, for the pager UI.
            "results_count": filtered_results.height,
        }
    )


async def _gather_filtered_results(search_data: SearchRequest, app):
    """Pipeline shared by both endpoints: dispatch on the search mode,
    compute similarity-scored results, then apply the user's filter string.

    Returns the filtered (but unsorted, unpaginated) results DataFrame.

    Raises:
        HTTPException(400): if ``search_data.mode`` is neither 'items' nor 'scale'.
    """
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=app
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=app
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=app
    )
    return await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )


@router.post("/export")
async def export_search_results(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    """Export the full (unpaginated) filtered search results as a CSV download.

    Runs the same pipeline as ``search``, sorts ALL rows, flattens the
    list-typed columns into '; '-joined strings and streams the CSV as an
    attachment.

    Raises:
        HTTPException(400): if ``search_data.mode`` is neither 'items' nor 'scale'.
    """
    filtered_results = await _gather_filtered_results(search_data, request.app)

    sorted_results = filtered_results.sort(
        by=search_data.sort_col,
        descending=search_data.sort_descending
    )

    # Flatten list columns so every CSV cell is a scalar string.
    export_df = sorted_results.select([
        "meta_doi",
        "meta_instrument_name",
        "max_similarity",
        pl.col("scale_name").list.join("; ").alias("scale_names"),
        # Round per-query similarities to 3 decimals before joining.
        pl.col("similarity").list.eval(pl.element().round(3).cast(pl.Utf8)).list.join("; ").alias("similarities"),
        pl.col("warning_codes").list.join("; ").alias("warnings"),
    ])

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(export_df.columns)
    writer.writerows(export_df.iter_rows())

    # getvalue() ignores the stream position, so no seek(0) is needed.
    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=synthnet_results.csv"}
    )