Amii2410 commited on
Commit
d671e56
·
verified ·
1 Parent(s): cfd3db5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -2,23 +2,32 @@ import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
  import networkx as nx
4
 
5
- # Load pre-trained model (runs only once at startup)
6
  model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
7
 
8
  def group_duplicates_api(complaints, threshold=0.7):
9
  """
10
- Input:
11
- complaints: list of complaint strings
12
- threshold: float between 0-1 (default 0.7)
13
- Output:
14
- List of groups, where each group is a list of complaint texts
15
  """
 
16
  if not complaints:
17
  return []
18
 
 
 
 
 
 
 
 
 
 
19
  embeddings = model.encode(complaints, convert_to_tensor=True)
20
  cosine_scores = util.pytorch_cos_sim(embeddings, embeddings)
21
 
 
22
  G = nx.Graph()
23
  G.add_nodes_from(range(len(complaints)))
24
 
@@ -27,15 +36,12 @@ def group_duplicates_api(complaints, threshold=0.7):
27
  if cosine_scores[i][j].item() >= threshold:
28
  G.add_edge(i, j)
29
 
 
30
  duplicate_groups = list(nx.connected_components(G))
31
- results = []
32
- for group in duplicate_groups:
33
- group_texts = [complaints[idx] for idx in group]
34
- results.append(group_texts)
35
-
36
  return results
37
 
38
- # Gradio interface
39
  demo = gr.Interface(
40
  fn=group_duplicates_api,
41
  inputs=[
@@ -44,8 +50,9 @@ demo = gr.Interface(
44
  ],
45
  outputs=gr.JSON(label="Duplicate Groups"),
46
  title="Duplicate Complaint Grouping API",
47
- description="Paste multiple complaints (one per line) and get grouped duplicates."
48
  )
49
 
50
  if __name__ == "__main__":
51
  demo.launch()
 
 
2
  from sentence_transformers import SentenceTransformer, util
3
  import networkx as nx
4
 
5
+ # Load the SentenceTransformer model once at startup
6
  model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
7
 
8
  def group_duplicates_api(complaints, threshold=0.7):
9
  """
10
+ Groups similar/duplicate complaints into clusters.
11
+ complaints: multiline string or list of strings
12
+ threshold: similarity score between 0 and 1
 
 
13
  """
14
+ # Handle empty input
15
  if not complaints:
16
  return []
17
 
18
+ # If using the textbox input, split by newline
19
+ if isinstance(complaints, str):
20
+ complaints = [c.strip() for c in complaints.split("\n") if c.strip()]
21
+
22
+ # If fewer than 2 complaints, nothing to compare
23
+ if len(complaints) < 2:
24
+ return [[c] for c in complaints]
25
+
26
+ # Compute embeddings and cosine similarities
27
  embeddings = model.encode(complaints, convert_to_tensor=True)
28
  cosine_scores = util.pytorch_cos_sim(embeddings, embeddings)
29
 
30
+ # Build similarity graph
31
  G = nx.Graph()
32
  G.add_nodes_from(range(len(complaints)))
33
 
 
36
  if cosine_scores[i][j].item() >= threshold:
37
  G.add_edge(i, j)
38
 
39
+ # Extract connected components as duplicate groups
40
  duplicate_groups = list(nx.connected_components(G))
41
+ results = [[complaints[idx] for idx in group] for group in duplicate_groups]
 
 
 
 
42
  return results
43
 
44
+ # Create Gradio interface
45
  demo = gr.Interface(
46
  fn=group_duplicates_api,
47
  inputs=[
 
50
  ],
51
  outputs=gr.JSON(label="Duplicate Groups"),
52
  title="Duplicate Complaint Grouping API",
53
+ description="Paste multiple complaints (one per line) and get grouped duplicates based on semantic similarity."
54
  )
55
 
56
  if __name__ == "__main__":
57
  demo.launch()
58
+