File size: 2,038 Bytes
ddbf83e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from transformers import pipeline
import os
import spaces

MODEL_NAME = os.getenv("MODEL_NAME", "timofeyk/roberta-query-router-ecommerce")

# Load the classifier once at import time. If anything goes wrong, keep the
# app running with `router_pipeline = None` so the request handler can report
# the failure in the UI instead of the Space crashing at startup.
try:
    router_pipeline = pipeline(
        "text-classification",
        model=MODEL_NAME,
        # Return a score for every label rather than only the top one. This
        # legacy flag also wraps the result for a single string input in an
        # extra list: [[{"label": ..., "score": ...}, ...]].
        return_all_scores=True,
        # Select the device through the pipeline itself (the documented way),
        # so both the model weights and the pipeline's input tensors are
        # placed on the GPU, rather than calling `.to()` on the pipeline.
        device="cuda",
    )
except Exception as e:
    print(f"Error loading model: {e}")
    router_pipeline = None

@spaces.GPU
def classify_query(query_text):
    """Score a search query for vector vs. lexical retrieval.

    Args:
        query_text: The query typed by the user; may be empty, blank or None.

    Returns:
        A dict mapping the display labels "Vector Search (Conceptual)" and
        "Lexical Search (Specific)" to the model's score for the
        `vector_search` and `lexical_search` labels respectively. A label the
        model does not report scores 0.0. A blank query scores 0.0 for both.

    Raises:
        gr.Error: If the model failed to load at startup.
    """
    if router_pipeline is None:
        # gr.Error surfaces the message in the UI; a string value inside a
        # dict cannot be rendered by gr.Label, which expects float scores.
        raise gr.Error("Model could not be loaded. Check Space logs for details.")

    if not query_text or not query_text.strip():
        # Same keys as the normal path so the Label output stays consistent.
        return {"Vector Search (Conceptual)": 0.0, "Lexical Search (Specific)": 0.0}

    predictions = router_pipeline(query_text)
    # With legacy `return_all_scores=True`, a single string yields
    # [[{label, score}, ...]]; with `top_k=None` it yields
    # [{label, score}, ...]. Accept either shape.
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]

    scores = {item["label"]: item["score"] for item in predictions}

    return {
        "Vector Search (Conceptual)": scores.get("vector_search", 0.0),
        "Lexical Search (Specific)": scores.get("lexical_search", 0.0),
    }

# ---------------------------------------------------------------------------
# UI copy shown above the form.
# ---------------------------------------------------------------------------
title = "E-commerce Query Router"

description = """
### Is the query conceptual or specific?
Enter an e-commerce query to determine if it's better for **vector search** (conceptual, broad) or **lexical search** (specific, keyword-based). The model will output the weights for each search type.
- **Conceptual Query Example:** "summer vibes clothing"
- **Specific Query Example:** "nike air force 1 size 10"
"""

# gr.Interface takes one list of input values per example; there is a
# single text input here, so each row wraps one query string.
examples = [
    [sample_query]
    for sample_query in (
        "father day gift",
        "16x16 pillow cover",
        "something to wear for a wedding",
        "logitech mx master 3s mouse",
        "comfortable office chair",
    )
]

# ---------------------------------------------------------------------------
# Components and interface wiring.
# ---------------------------------------------------------------------------
_query_input = gr.Textbox(
    label="E-commerce Search Query",
    placeholder="Enter your product query here...",
    lines=1,
)

# Display both search-type weights returned by classify_query.
_weights_output = gr.Label(num_top_classes=2, label="Search Type Weights")

app = gr.Interface(
    fn=classify_query,
    inputs=_query_input,
    outputs=_weights_output,
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

if __name__ == "__main__":
    app.launch()