win2win committed on
Commit
914ad9a
·
verified ·
1 Parent(s): 784bcca

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +42 -0
predict.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ from sklearn.neighbors import NearestNeighbors
6
+
7
# Load the sentence-embedding model and the reference listings.
model = SentenceTransformer('models/ad_categorizer')
# JSON text is UTF-8; pin the encoding so loading does not depend on the
# platform's default locale (which is not UTF-8 everywhere).
with open('data/listings.json', encoding='utf-8') as f:
    listings = json.load(f)

# Embed every reference listing once at import time so that each query only
# has to encode its own text.
# NOTE(review): each listing is expected to be a dict with at least the
# 'text' and 'category' keys read below -- confirm against the data file.
texts = [item['text'] for item in listings]
embeddings = model.encode(texts)
# Kept as a module-level name for backward compatibility; nothing in the
# visible part of this file reads it.
categories = [item['category'] for item in listings]

# Nearest-neighbour index over the listing embeddings; row i of the index
# corresponds to listings[i]. Uses scikit-learn's default distance metric.
nn = NearestNeighbors(n_neighbors=1).fit(embeddings)
19
+
20
def categorize(text):
    """Categorize an ad by copying the labels of its nearest stored listing.

    The text is embedded with the module-level ``model`` and looked up in the
    module-level ``nn`` index; the closest entry of ``listings`` supplies the
    prediction.

    Args:
        text: The ad text to classify; passed unchanged to ``model.encode``.

    Returns:
        A dict with three keys: ``"category"`` and ``"category_id"`` taken
        from the matched listing, and ``"similar_listing"`` holding that
        listing's text.

    Raises:
        KeyError: If the matched listing lacks ``'category'``,
            ``'category_id'`` or ``'text'``.
    """
    vector = model.encode(text)

    # kneighbors expects a batch of queries, so wrap the single vector in a
    # list; only the neighbour indices are needed, not the distances.
    neighbor_ids = nn.kneighbors([vector], return_distance=False)
    nearest = listings[neighbor_ids[0][0]]

    prediction = {}
    prediction["category"] = nearest['category']
    prediction["category_id"] = nearest['category_id']
    prediction["similar_listing"] = nearest['text']
    return prediction
33
+
34
# Gradio interface
# Read the example inputs inside a context manager so the file handle is
# closed deterministically, and pin UTF-8 so decoding does not depend on the
# platform locale.
with open('data/test_cases.json', encoding='utf-8') as _examples_file:
    _example_inputs = json.load(_examples_file)

# Single text box in, JSON prediction out; `categorize` does the work.
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(label="Ad Listing"),
    outputs=gr.JSON(label="Prediction"),
    examples=_example_inputs,
)

# Launched at module level (no __main__ guard) to keep the original behaviour
# of starting the app as soon as this file is executed.
demo.launch()