victor7246 commited on
Commit
cdaceed
·
verified ·
1 Parent(s): a814b94

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +25 -23
utils.py CHANGED
@@ -33,6 +33,29 @@ from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
33
 
34
  emb_model = SentenceTransformer("all-MiniLM-L6-v2")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  xls = pd.ExcelFile('SmartClever table explanations.xlsx')
37
  metadata_df = pd.DataFrame()
38
  i = 0
@@ -49,6 +72,8 @@ for sheet_name in xls.sheet_names:
49
 
50
  i += 1
51
 
 
 
52
  def extract_question_type(llm, query):
53
  messages = [
54
  (
@@ -69,29 +94,6 @@ def extract_question_type(llm, query):
69
  else:
70
  return 'unknown'
71
 
72
- class EmbeddingsSearch:
73
- def __init__(self, metadata_df, emb_model):
74
-
75
- self.model = emb_model
76
- self.metadata_df = metadata_df
77
- self.embeddings = self.model.encode(self.metadata_df['desc'].tolist())
78
-
79
- def __call__(self, text: str, topk: int = 5):
80
- q_emb = self.model.encode([text])
81
- distances = cosine_similarity(q_emb, self.embeddings)
82
- idx = np.flip(distances.argsort())[0]
83
- distances.sort()
84
- distances = np.flip(distances)[0]
85
-
86
- results = pd.DataFrame()
87
- results['idx'] = idx.tolist()[:topk]
88
- results['distances'] = distances.tolist()[:topk]
89
-
90
- results['table'] = [
91
- self.metadata_df.loc[i, "table"] for i in results['idx']
92
- ]
93
- return results
94
-
95
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
96
 
97
  intermediate_steps_KEY = "intermediate_steps"
 
33
 
34
  emb_model = SentenceTransformer("all-MiniLM-L6-v2")
35
 
36
+ class EmbeddingsSearch:
37
+ def __init__(self, metadata_df, emb_model):
38
+
39
+ self.model = emb_model
40
+ self.metadata_df = metadata_df
41
+ self.embeddings = self.model.encode(self.metadata_df['desc'].tolist())
42
+
43
+ def __call__(self, text: str, topk: int = 5):
44
+ q_emb = self.model.encode([text])
45
+ distances = cosine_similarity(q_emb, self.embeddings)
46
+ idx = np.flip(distances.argsort())[0]
47
+ distances.sort()
48
+ distances = np.flip(distances)[0]
49
+
50
+ results = pd.DataFrame()
51
+ results['idx'] = idx.tolist()[:topk]
52
+ results['distances'] = distances.tolist()[:topk]
53
+
54
+ results['table'] = [
55
+ self.metadata_df.loc[i, "table"] for i in results['idx']
56
+ ]
57
+ return results
58
+
59
  xls = pd.ExcelFile('SmartClever table explanations.xlsx')
60
  metadata_df = pd.DataFrame()
61
  i = 0
 
72
 
73
  i += 1
74
 
75
+ table_search = EmbeddingsSearch(metadata_df=metadata_df, emb_model=emb_model)
76
+
77
  def extract_question_type(llm, query):
78
  messages = [
79
  (
 
94
  else:
95
  return 'unknown'
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
98
 
99
  intermediate_steps_KEY = "intermediate_steps"