Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
|
@@ -13,10 +13,13 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
| 13 |
|
| 14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 15 |
# Load embedder once
|
| 16 |
-
embedder = SentenceTransformer("
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def embed_text_list(texts):
|
| 19 |
-
return embedder.encode(texts, convert_to_tensor=False, device=device)
|
|
|
|
| 20 |
|
| 21 |
def rank_by_centrality(texts):
|
| 22 |
embeddings = embed_text_list(texts)
|
|
@@ -45,9 +48,12 @@ def cluster_and_rank(texts, threshold=0.75):
|
|
| 45 |
return representative_texts
|
| 46 |
|
| 47 |
def process_single_patent(patent_dict):
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Cluster & rank
|
| 53 |
top_claims = cluster_and_rank(claims)
|
|
@@ -225,6 +231,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
|
|
| 225 |
|
| 226 |
def main():
|
| 227 |
base_directory = os.getcwd()
|
|
|
|
| 228 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
| 229 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
| 230 |
help='Path to pre-ranking JSON file')
|
|
@@ -252,7 +259,7 @@ def main():
|
|
| 252 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
| 253 |
help='Device to use (cuda/cpu)')
|
| 254 |
parser.add_argument('--base_dir', type=str,
|
| 255 |
-
default=f'{base_directory}/
|
| 256 |
help='Base directory for data files')
|
| 257 |
|
| 258 |
args = parser.parse_args()
|
|
|
|
| 13 |
|
| 14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 15 |
# Load embedder once
|
| 16 |
+
# embedder = SentenceTransformer("all-MiniLM-L6-v2").to(device)
|
| 17 |
+
embedder = SentenceTransformer("intfloat/e5-large-v2").to(device)
|
| 18 |
+
|
| 19 |
|
| 20 |
def embed_text_list(texts):
|
| 21 |
+
# return embedder.encode(texts, convert_to_tensor=False, device=device)
|
| 22 |
+
return embedder.encode(["query: your sentence here"], convert_to_tensor=False, device=device)
|
| 23 |
|
| 24 |
def rank_by_centrality(texts):
|
| 25 |
embeddings = embed_text_list(texts)
|
|
|
|
| 48 |
return representative_texts
|
| 49 |
|
| 50 |
def process_single_patent(patent_dict):
|
| 51 |
+
def filter_short_texts(texts, min_tokens=5):
|
| 52 |
+
return [text for text in texts if len(text.split()) >= min_tokens]
|
| 53 |
+
|
| 54 |
+
claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
|
| 55 |
+
paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
|
| 56 |
+
features = filter_short_texts([v for k, v in patent_dict.get("features", {}).items()])
|
| 57 |
|
| 58 |
# Cluster & rank
|
| 59 |
top_claims = cluster_and_rank(claims)
|
|
|
|
| 231 |
|
| 232 |
def main():
|
| 233 |
base_directory = os.getcwd()
|
| 234 |
+
base_directory += "/Patent_Retrieval"
|
| 235 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
| 236 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
| 237 |
help='Path to pre-ranking JSON file')
|
|
|
|
| 259 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
| 260 |
help='Device to use (cuda/cpu)')
|
| 261 |
parser.add_argument('--base_dir', type=str,
|
| 262 |
+
default=f'{base_directory}/datasets',
|
| 263 |
help='Base directory for data files')
|
| 264 |
|
| 265 |
args = parser.parse_args()
|