import csv
import io
import logging
from typing import Annotated

import polars as pl
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.templating import Jinja2Templates

from src.config import config
from src.schemas.search_request import SearchRequest
from src.services.search_service import (
    compute_search_results,
    filter_search,
    refine_search,
    semantic_item_search,
    semantic_scale_search,
)
from src.services.usage_tracking_service import send_tracking_event_to_db
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

router = APIRouter()
templates = Jinja2Templates(directory="./src/templates")


@router.post("/")
async def search(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    logger.info(f"Processing search request:\n{search_data.model_dump_json()}")

    # Record the search for usage analytics, if tracking is enabled.
    if config.tracking_db.enabled:
        with context_logger("Sending search usage data to tracking DB"):
            await send_tracking_event_to_db(search_data=search_data)
    else:
        logger.warning("Tracking disabled in config / ENV, skipping DB write!")

    # Dispatch to the embedding search that matches the requested mode.
    if search_data.mode == "items":
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app,
        )
    elif search_data.mode == "scale":
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app,
    )

    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )

    # Sort and paginate the filtered results for display.
    search_results = await refine_search(
        df=filtered_results,
        sort_col=search_data.sort_col,
        sort_descending=search_data.sort_descending,
        page_index=search_data.page_index,
        page_size=search_data.page_size,
    )

    return templates.TemplateResponse(
        name="pages/search/partials/_search_results.jinja",
        context={
            "request": request,
            "search_mode": search_data.mode,
            "search_result": search_results.to_dicts(),
            "filter_string": search_data.filter_string,
            "sort_col": search_data.sort_col,
            "sort_descending": search_data.sort_descending,
            "page_index": search_data.page_index,
            "page_size": search_data.page_size,
            "results_count": filtered_results.height,
        },
    )


@router.post("/export")
async def export_search_results(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
    # Re-run the search pipeline without pagination so the export covers
    # the full filtered result set, not just the current page.
    if search_data.mode == "items":
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app,
        )
    elif search_data.mode == "scale":
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app,
    )

    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )

    sorted_results = filtered_results.sort(
        by=search_data.sort_col,
        descending=search_data.sort_descending,
    )

    # Flatten list columns into "; "-joined strings so the export is plain CSV.
    export_df = sorted_results.select([
        "meta_doi",
        "meta_instrument_name",
        "max_similarity",
        pl.col("scale_name").list.join("; ").alias("scale_names"),
        pl.col("similarity")
        .list.eval(pl.element().round(3).cast(pl.Utf8))
        .list.join("; ")
        .alias("similarities"),
        pl.col("warning_codes").list.join("; ").alias("warnings"),
    ])

    output = io.StringIO()
    writer = csv.writer(output)

    # Write header and rows
    writer.writerow(export_df.columns)
    writer.writerows(export_df.iter_rows())

    # Serve the buffered CSV as a file download.
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=synthnet_results.csv"},
    )