timofey commited on
Commit
ddbf83e
·
1 Parent(s): 64ea315
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import os
4
+ import spaces
5
+
6
+ MODEL_NAME = os.getenv("MODEL_NAME", "timofeyk/roberta-query-router-ecommerce")
7
+ try:
8
+ router_pipeline = pipeline(
9
+ "text-classification",
10
+ model=MODEL_NAME,
11
+ return_all_scores=True
12
+ )
13
+ router_pipeline.to('cuda')
14
+ except Exception as e:
15
+ print(f"Error loading model: {e}")
16
+ router_pipeline = None
17
+
18
+ @spaces.GPU
19
+ def classify_query(query_text):
20
+ if not router_pipeline:
21
+ return {"Error": "Model could not be loaded. Check Space logs for details."}
22
+
23
+ if not query_text or not query_text.strip():
24
+ return {"Vector Search": 0.0, "Lexical Search": 0.0}
25
+
26
+ predictions = router_pipeline(query_text)[0]
27
+
28
+ scores = {item['label']: item['score'] for item in predictions}
29
+
30
+ output_scores = {
31
+ "Vector Search (Conceptual)": scores.get('vector_search', 0.0),
32
+ "Lexical Search (Specific)": scores.get('lexical_search', 0.0)
33
+ }
34
+
35
+ return output_scores
36
+
37
+ title = "E-commerce Query Router"
38
+ description = """
39
+ ### Is the query conceptual or specific?
40
+ Enter an e-commerce query to determine if it's better for **vector search** (conceptual, broad) or **lexical search** (specific, keyword-based). The model will output the weights for each search type.
41
+ - **Conceptual Query Example:** "summer vibes clothing"
42
+ - **Specific Query Example:** "nike air force 1 size 10"
43
+ """
44
+ examples = [
45
+ ["father day gift"],
46
+ ["16x16 pillow cover"],
47
+ ["something to wear for a wedding"],
48
+ ["logitech mx master 3s mouse"],
49
+ ["comfortable office chair"],
50
+ ]
51
+
52
+ app = gr.Interface(
53
+ fn=classify_query,
54
+ inputs=gr.Textbox(
55
+ lines=1,
56
+ label="E-commerce Search Query",
57
+ placeholder="Enter your product query here..."
58
+ ),
59
+ outputs=gr.Label(
60
+ label="Search Type Weights",
61
+ num_top_classes=2
62
+ ),
63
+ title=title,
64
+ description=description,
65
+ examples=examples,
66
+ theme=gr.themes.Soft(),
67
+ allow_flagging="never"
68
+ )
69
+
70
+ if __name__ == "__main__":
71
+ app.launch()