Amii2410 commited on
Commit
5bbbd7a
·
verified ·
1 Parent(s): 34a9978

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import networkx as nx
4
+
5
+ # Load pre-trained model (runs only once at startup)
6
+ model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
7
+
8
+ def group_duplicates_api(complaints, threshold=0.7):
9
+ """
10
+ Input:
11
+ complaints: list of complaint strings
12
+ threshold: float between 0-1 (default 0.7)
13
+ Output:
14
+ List of groups, where each group is a list of complaint texts
15
+ """
16
+ if not complaints:
17
+ return []
18
+
19
+ embeddings = model.encode(complaints, convert_to_tensor=True)
20
+ cosine_scores = util.pytorch_cos_sim(embeddings, embeddings)
21
+
22
+ G = nx.Graph()
23
+ G.add_nodes_from(range(len(complaints)))
24
+
25
+ for i in range(len(complaints)):
26
+ for j in range(i + 1, len(complaints)):
27
+ if cosine_scores[i][j].item() >= threshold:
28
+ G.add_edge(i, j)
29
+
30
+ duplicate_groups = list(nx.connected_components(G))
31
+ results = []
32
+ for group in duplicate_groups:
33
+ group_texts = [complaints[idx] for idx in group]
34
+ results.append(group_texts)
35
+
36
+ return results
37
+
38
+ # Gradio interface
39
+ demo = gr.Interface(
40
+ fn=group_duplicates_api,
41
+ inputs=[
42
+ gr.Textbox(lines=10, placeholder="Enter complaints separated by newline", label="Complaints"),
43
+ gr.Slider(0.5, 0.95, value=0.7, step=0.01, label="Similarity Threshold")
44
+ ],
45
+ outputs=gr.JSON(label="Duplicate Groups"),
46
+ title="Duplicate Complaint Grouping API",
47
+ description="Paste multiple complaints (one per line) and get grouped duplicates."
48
+ )
49
+
50
+ if __name__ == "__main__":
51
+ demo.launch()