kamkol commited on
Commit
cf0b4fb
·
1 Parent(s): abe7dd0

Directly calculate embedding similarity

Browse files
Files changed (3) hide show
  1. app/app.py +35 -5
  2. process_data.py +43 -4
  3. requirements.txt +3 -1
app/app.py CHANGED
@@ -16,7 +16,9 @@ from langchain_core.tools import tool
16
  from langchain_openai import ChatOpenAI
17
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
18
  from langchain.schema.output_parser import StrOutputParser
19
- from sentence_transformers import SentenceTransformer
 
 
20
  from langchain_core.vectorstores import VectorStore
21
  from langchain_core.documents import Document
22
  from langgraph.graph import StateGraph, END
@@ -105,9 +107,10 @@ def find_processed_data():
105
  """Find the processed_data directory path"""
106
  # Check common locations
107
  possible_paths = [
 
108
  "data/processed_data",
109
  "app/data/processed_data",
110
- "/data/processed_data"
111
  ]
112
 
113
  for path in possible_paths:
@@ -122,6 +125,32 @@ def find_processed_data():
122
 
123
  raise FileNotFoundError("Could not find processed_data directory")
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  # Initialize the vectorstore
126
  @st.cache_resource
127
  def initialize_vectorstore():
@@ -146,11 +175,12 @@ def initialize_vectorstore():
146
  except Exception as e:
147
  embedded_docs = []
148
  raise RuntimeError(f"Error loading embedded_docs.pkl: {str(e)}")
149
-
150
- # Initialize embedding model - use SentenceTransformer directly
151
  model_name = "kamkol/ab_testing_finetuned_arctic_ft-36dfff22-0696-40d2-b3bf-268fe2ff2aec"
152
  try:
153
- embedding_model = SentenceTransformer(model_name)
 
154
  except Exception as e:
155
  print(f"Error loading model: {str(e)}")
156
  raise RuntimeError(f"Error initializing SentenceTransformer model: {str(e)}")
 
16
  from langchain_openai import ChatOpenAI
17
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
18
  from langchain.schema.output_parser import StrOutputParser
19
+ from transformers import AutoModel, AutoTokenizer
20
+ import torch
21
+ import torch.nn.functional as F
22
  from langchain_core.vectorstores import VectorStore
23
  from langchain_core.documents import Document
24
  from langgraph.graph import StateGraph, END
 
107
  """Find the processed_data directory path"""
108
  # Check common locations
109
  possible_paths = [
110
+ "/data/processed_data",
111
  "data/processed_data",
112
  "app/data/processed_data",
113
+ "/app/data/processed_data"
114
  ]
115
 
116
  for path in possible_paths:
 
125
 
126
  raise FileNotFoundError("Could not find processed_data directory")
127
 
128
class ArcticEmbedder:
    """Produce L2-normalised sentence embeddings from a HuggingFace encoder.

    Loads the tokenizer/model pair named by ``model_name`` and keeps the
    model on GPU when one is available, otherwise on CPU.
    """

    def __init__(self, model_name):
        # Pull the tokenizer and encoder weights from the Hub (or local cache).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Prefer CUDA when present; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring padded positions."""
        hidden = model_output.last_hidden_state
        # Broadcast the attention mask over the embedding dimension so that
        # padding tokens contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        summed = torch.sum(hidden * mask, 1)
        # Clamp guards against division by zero on an all-padding row.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode(self, query):
        """Embed a single query string.

        Returns a flat Python list of floats: the mean-pooled,
        L2-normalised embedding of ``query``.
        """
        tokens = self.tokenizer(
            [query],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            output = self.model(**tokens)

        pooled = self._mean_pooling(output, tokens["attention_mask"])
        unit = F.normalize(pooled, p=2, dim=1)
        return unit.cpu().numpy().flatten().tolist()
153
+
154
  # Initialize the vectorstore
155
  @st.cache_resource
156
  def initialize_vectorstore():
 
175
  except Exception as e:
176
  embedded_docs = []
177
  raise RuntimeError(f"Error loading embedded_docs.pkl: {str(e)}")
178
+
179
+ # Initialize custom embedding model
180
  model_name = "kamkol/ab_testing_finetuned_arctic_ft-36dfff22-0696-40d2-b3bf-268fe2ff2aec"
181
  try:
182
+ embedding_model = ArcticEmbedder(model_name)
183
+
184
  except Exception as e:
185
  print(f"Error loading model: {str(e)}")
186
  raise RuntimeError(f"Error initializing SentenceTransformer model: {str(e)}")
process_data.py CHANGED
@@ -12,7 +12,9 @@ from langchain_community.document_loaders import DirectoryLoader
12
  from langchain_community.document_loaders import PyPDFLoader
13
  from langchain_core.documents import Document
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
- from sentence_transformers import SentenceTransformer
 
 
16
  from langchain_community.vectorstores import Qdrant
17
  from qdrant_client import QdrantClient
18
  from qdrant_client.models import Distance, VectorParams
@@ -58,6 +60,40 @@ def clean_directory(directory_path):
58
  path.mkdir(parents=True, exist_ok=True)
59
  print(f"Created clean directory: {directory_path}")
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def process_pdfs():
62
  """Process PDFs and create vectorstore"""
63
  print("Processing PDFs...")
@@ -150,10 +186,13 @@ def process_pdfs():
150
  with open(processed_data_dir / "chunks.pkl", "wb") as f:
151
  pickle.dump(split_chunks, f)
152
 
153
- # Initialize embedding model using SentenceTransformer directly
 
 
154
  try:
155
- embedding_model = SentenceTransformer("kamkol/ab_testing_finetuned_arctic_ft-36dfff22-0696-40d2-b3bf-268fe2ff2aec")
156
- print("Successfully loaded SentenceTransformer model")
 
157
  except Exception as e:
158
  print(f"Error loading model: {str(e)}")
159
  raise RuntimeError(f"Error initializing SentenceTransformer model: {str(e)}")
 
12
  from langchain_community.document_loaders import PyPDFLoader
13
  from langchain_core.documents import Document
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from transformers import AutoModel, AutoTokenizer
16
+ import torch
17
+ import torch.nn.functional as F
18
  from langchain_community.vectorstores import Qdrant
19
  from qdrant_client import QdrantClient
20
  from qdrant_client.models import Distance, VectorParams
 
60
  path.mkdir(parents=True, exist_ok=True)
61
  print(f"Created clean directory: {directory_path}")
62
 
63
class ArcticEmbedder:
    """Batched sentence embedder built on a HuggingFace encoder.

    Loads the tokenizer/model pair named by ``model_name`` and keeps the
    model on GPU when one is available, otherwise on CPU.
    """

    def __init__(self, model_name):
        # Pull the tokenizer and encoder weights from the Hub (or local cache).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Prefer CUDA when present; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring padded positions."""
        hidden = model_output.last_hidden_state
        # Broadcast the attention mask over the embedding dimension so that
        # padding tokens contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        summed = torch.sum(hidden * mask, 1)
        # Clamp guards against division by zero on an all-padding row.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode(self, texts, batch_size=32):
        """Embed a list of texts in batches of ``batch_size``.

        Returns a 2-D numpy array (len(texts) x dim) of mean-pooled,
        L2-normalised embeddings.
        """
        # NOTE(review): relies on numpy being available as ``np`` at module
        # level — confirm process_data.py imports it elsewhere.
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]

            tokens = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)

            # Inference only — no gradients needed.
            with torch.no_grad():
                output = self.model(**tokens)

            pooled = self._mean_pooling(output, tokens["attention_mask"])
            unit = F.normalize(pooled, p=2, dim=1)
            chunks.append(unit.cpu().numpy())

        return np.concatenate(chunks)
96
+
97
  def process_pdfs():
98
  """Process PDFs and create vectorstore"""
99
  print("Processing PDFs...")
 
186
  with open(processed_data_dir / "chunks.pkl", "wb") as f:
187
  pickle.dump(split_chunks, f)
188
 
189
+
190
+
191
+ # Initialize custom embedding model
192
  try:
193
+ embedding_model = ArcticEmbedder("kamkol/ab_testing_finetuned_arctic_ft-36dfff22-0696-40d2-b3bf-268fe2ff2aec")
194
+ print("Successfully loaded ArcticEmbedder model")
195
+
196
  except Exception as e:
197
  print(f"Error loading model: {str(e)}")
198
  raise RuntimeError(f"Error initializing SentenceTransformer model: {str(e)}")
requirements.txt CHANGED
@@ -10,4 +10,6 @@ tiktoken>=0.6.0
10
  python-dotenv>=1.0.1
11
  qdrant-client>=1.7.0
12
  scipy>=1.10.0
13
- sentence-transformers==2.3.0
 
 
 
10
  python-dotenv>=1.0.1
11
  qdrant-client>=1.7.0
12
  scipy>=1.10.0
13
+ sentence-transformers==2.3.0
14
+ transformers>=4.51.3
15
+ torch>=2.0.1