import os

import h5py
import pandas as pd
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util

from src.utils.path_utils import get_project_root

class SemanticSimilarity:
    def __init__(
        self,
        train_embeddings_file,
        test_embeddings_file,
        train_csv_path=None,
        test_csv_path=None,
        train_df=None,
        test_df=None,
    ):
        # Bi-encoder for candidate retrieval (tuned for dot-product search).
        self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.bi_encoder.max_seq_length = 512

        # Cross-encoder for re-ranking the bi-encoder's candidates.
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

        # Pre-computed corpus embeddings and their string IDs.
        self.train_embeddings, self.train_ids = self._load_embeddings(
            train_embeddings_file
        )
        self.test_embeddings, self.test_ids = self._load_embeddings(
            test_embeddings_file
        )

        # Accept either pre-loaded DataFrames or CSV paths.
        self.train_csv = (
            train_df if train_df is not None else pd.read_csv(train_csv_path)
        )
        self.test_csv = test_df if test_df is not None else pd.read_csv(test_csv_path)

    def _load_embeddings(self, h5_file_path):
        """Load embeddings and IDs from an HDF5 file."""
        with h5py.File(h5_file_path, "r") as h5_file:
            embeddings = torch.tensor(h5_file["embeddings"][:], dtype=torch.float16)
            ids = list(h5_file["ids"][:])
        return embeddings, ids

    def search(self, query, top_k):
        """Retrieve-and-rerank: bi-encoder recall, cross-encoder precision."""
        question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        # Match the float16 dtype of the stored corpus embeddings.
        question_embedding = question_embedding.to(dtype=torch.float16)

        # Over-retrieve (top_k * 5 candidates per corpus) so the
        # cross-encoder has a pool to re-rank.
        hits_train = util.semantic_search(
            question_embedding, self.train_embeddings, top_k=top_k * 5
        )[0]
        hits_test = util.semantic_search(
            question_embedding, self.test_embeddings, top_k=top_k * 5
        )[0]

        # Score each (query, evidence) pair with the cross-encoder.
        # corpus_id is a positional index, so use .iloc, not label indexing.
        cross_inp_train = [
            [query, self.train_csv["evidence_enriched"].iloc[hit["corpus_id"]]]
            for hit in hits_train
        ]
        cross_scores_train = self.cross_encoder.predict(cross_inp_train)

        cross_inp_test = [
            [query, self.test_csv["evidence_enriched"].iloc[hit["corpus_id"]]]
            for hit in hits_test
        ]
        cross_scores_test = self.cross_encoder.predict(cross_inp_test)

        for idx, score in enumerate(cross_scores_train):
            hits_train[idx]["cross-score"] = score
        for idx, score in enumerate(cross_scores_test):
            hits_test[idx]["cross-score"] = score

        # Re-rank each candidate pool by cross-encoder score.
        hits_train = sorted(hits_train, key=lambda x: x["cross-score"], reverse=True)
        hits_test = sorted(hits_test, key=lambda x: x["cross-score"], reverse=True)

        # Merge both pools into (id, score) pairs; stored IDs are bytes.
        results = [
            (self.train_ids[hit["corpus_id"]].decode("utf-8"), hit["cross-score"])
            for hit in hits_train
        ] + [
            (self.test_ids[hit["corpus_id"]].decode("utf-8"), hit["cross-score"])
            for hit in hits_test
        ]

        # Keep the top_k highest-scoring results, skipping duplicate scores
        # (a crude near-duplicate filter).
        unique_scores = set()
        filtered_results = []
        for id_, score in sorted(results, key=lambda x: x[1], reverse=True):
            if score not in unique_scores:
                unique_scores.add(score)
                filtered_results.append((id_, score))
            if len(filtered_results) == top_k:
                break

        return filtered_results
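
# A minimal sketch (not part of the original module) of the HDF5 layout that
# _load_embeddings expects. The file name, row count, and random vectors are
# illustrative; 768 matches the multi-qa-mpnet-base-dot-v1 embedding size, and
# only the dataset names ("embeddings", "ids") and dtypes mirror the loader.
def write_example_embeddings(path="example_embeddings.h5"):
    import numpy as np

    vectors = np.random.rand(3, 768).astype("float16")
    ids = ["train_0", "train_1", "train_2"]
    with h5py.File(path, "w") as f:
        f.create_dataset("embeddings", data=vectors, dtype="float16")
        f.create_dataset("ids", data=ids, dtype=h5py.string_dtype())
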
class TextCorpus:
    def __init__(self, data_dir, split):
        self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
        self.split = split
        self.data_dir = data_dir

    def encode_corpus(self):
        """
        Encode the corpus (the evidence_enriched column for this split) and
        store the embeddings, with split-prefixed IDs, in an HDF5 file.
        """
        file_path = os.path.join(self.data_dir, f"{self.split}_enriched.csv")
        df = pd.read_csv(file_path)

        evidence_enriched = df["evidence_enriched"].tolist()
        ids = df["id"].tolist()

        embeddings = self.bi_encoder.encode(evidence_enriched, convert_to_tensor=True)

        h5_file_path = os.path.join(get_project_root(), f"{self.split}_embeddings.h5")

        with h5py.File(h5_file_path, "w") as h5_file:
            # Move to CPU first in case encoding ran on a GPU.
            h5_file.create_dataset(
                "embeddings", data=embeddings.cpu().numpy(), dtype="float16"
            )
            # Prefix IDs with the split so merged search results stay traceable.
            h5_file.create_dataset(
                "ids",
                data=[f"{self.split}_{id_}" for id_ in ids],
                dtype=h5py.string_dtype(),
            )

        print(f"Embeddings saved to {h5_file_path}")

if __name__ == "__main__":
    import time

    start_time = time.time()
    project_root = get_project_root()
    data_dir = os.path.join(project_root, "data", "preprocessed")

    train_csv_path = os.path.join(data_dir, "train_enriched.csv")
    test_csv_path = os.path.join(data_dir, "test_enriched.csv")
    train_embeddings_file = os.path.join(project_root, "train_embeddings.h5")
    test_embeddings_file = os.path.join(project_root, "test_embeddings.h5")

    # Build the searcher from the pre-computed embeddings and enriched CSVs.
    similarity = SemanticSimilarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
    )

    # Use an enriched claim from the training set as a sample query.
    train_df = pd.read_csv(train_csv_path)
    first_query = train_df["claim_enriched"].iloc[2]

    top_k = 5

    results = similarity.search(query=first_query, top_k=top_k)
    finish_time = time.time() - start_time

    print(results)
    print(f"Finish time: {finish_time:.2f}s")