import io
import csv
import logging
import polars as pl
from fastapi.responses import StreamingResponse
from typing import Annotated
from fastapi import APIRouter, HTTPException, Request, Form
from fastapi.templating import Jinja2Templates

from src.config import config
from src.schemas.search_request import SearchRequest
from src.services.usage_tracking_service import send_tracking_event_to_db
from src.services.search_service import (
    semantic_item_search,
    semantic_scale_search,
    compute_search_results,
    refine_search,
    filter_search,
)
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

router = APIRouter()

templates = Jinja2Templates(directory="./src/templates")

@router.post("/")
async def search(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
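    """Run a semantic search and render the search results partial."""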
    logger.info(f"Processing search request: \n{search_data.model_dump_json()}")

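    # Record the search in the usage-tracking database when tracking is enabled.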
    if config.tracking_db.enabled:
        with context_logger("Sending search usage data to tracking DB"):
            await send_tracking_event_to_db(
                search_data=search_data
            )
    else:
        logger.warning("Tracking disabled in config / ENV, skipping DB write!")

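    # Dispatch to item-level or scale-level semantic search based on the requested mode.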
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

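    # Build the full results DataFrame from the raw similarity scores.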
    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app
    )

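    # Apply the user-supplied filter string before sorting and pagination.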
    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )

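    # Sort and paginate the filtered results for the current page.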
    search_results = await refine_search(
        df=filtered_results,
        sort_col=search_data.sort_col,
        sort_descending=search_data.sort_descending,
        page_index=search_data.page_index,
        page_size=search_data.page_size,
    )

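    # Render the results partial, echoing the filter, sort, and pagination state back to the client.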
    return templates.TemplateResponse(
        name="pages/search/partials/_search_results.jinja",
        context={
            "request": request,
            "search_mode": search_data.mode,
            "search_result": search_results.to_dicts(),
            "filter_string": search_data.filter_string,
            "sort_col": search_data.sort_col,
            "sort_descending": search_data.sort_descending,
            "page_index": search_data.page_index,
            "page_size": search_data.page_size,
            "results_count": filtered_results.height,
        }
    )

@router.post("/export")
async def export_search_results(
    request: Request,
    search_data: Annotated[SearchRequest, Form()],
):
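    """Re-run the search and stream the complete result set as a CSV download."""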
    
    if search_data.mode == 'items':
        similarities = await semantic_item_search(
            queries=search_data.items,
            app=request.app
        )
    elif search_data.mode == 'scale':
        similarities = await semantic_scale_search(
            queries=search_data.scale,
            app=request.app
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid search mode")

    raw_search_results = await compute_search_results(
        similarities=similarities,
        app=request.app
    )

    filtered_results = await filter_search(
        df=raw_search_results,
        filter_string=search_data.filter_string,
    )

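    # Exports skip pagination: sort the complete filtered result set.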
    sorted_results = filtered_results.sort(
        by=search_data.sort_col, 
        descending=search_data.sort_descending
    )

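    # Build the CSV in an in-memory buffer.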
    output = io.StringIO()
    writer = csv.writer(output)
    
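    # Flatten list columns into "; "-joined strings so each result exports as a single CSV row.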
    export_df = sorted_results.select([
        "meta_doi",
        "meta_instrument_name",
        "max_similarity",
        pl.col("scale_name").list.join("; ").alias("scale_names"),
        pl.col("similarity").list.eval(pl.element().round(3).cast(pl.Utf8)).list.join("; ").alias("similarities"),
        pl.col("warning_codes").list.join("; ").alias("warnings"),
    ])
    
    # Write header and rows
    writer.writerow(export_df.columns)
    writer.writerows(export_df.iter_rows())
    
    output.seek(0)
    
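    # Stream the buffer back as a CSV file download.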
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=synthnet_results.csv"}
    )