Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
|
@@ -70,6 +70,29 @@ def process_single_patent(patent_dict):
|
|
| 70 |
"features": rank_by_centrality(top_features),
|
| 71 |
}
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def load_json_file(file_path):
|
| 74 |
"""Load JSON data from a file"""
|
| 75 |
with open(file_path, 'r') as f:
|
|
@@ -145,10 +168,8 @@ def extract_text(content_dict, text_type="full"):
|
|
| 145 |
filtered_dict = process_single_patent(content_dict)
|
| 146 |
all_text = []
|
| 147 |
# Start with abstract for better context at the beginning
|
| 148 |
-
if "
|
| 149 |
-
all_text.append(content_dict["
|
| 150 |
-
# if "pa01" in content_dict:
|
| 151 |
-
# all_text.append(content_dict["pa01"])
|
| 152 |
|
| 153 |
# For claims, paragraphs and features, we take only the top-10 most relevant
|
| 154 |
# Add claims
|
|
@@ -162,6 +183,26 @@ def extract_text(content_dict, text_type="full"):
|
|
| 162 |
all_text.append(paragraph)
|
| 163 |
|
| 164 |
return " ".join(all_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
return ""
|
|
|
|
| 70 |
"features": rank_by_centrality(top_features),
|
| 71 |
}
|
| 72 |
|
| 73 |
+
def process_single_patent2(patent_dict):
    """Filter and re-rank a patent's claims, paragraphs and features.

    Texts shorter than ``min_tokens`` words are dropped; claims and
    features are ranked by centrality directly, while paragraphs are
    first clustered (``cluster_and_rank``) and then ranked.

    Args:
        patent_dict: mapping of section keys to texts. Claims use keys
            starting with "c-en", paragraphs keys starting with "p", and
            an optional "features" sub-dict holds feature texts.
            NOTE(review): any key starting with "p" — e.g. the abstract
            key "pa01" used elsewhere in this file — is swept into the
            paragraph list; confirm this is intended.

    Returns:
        dict with "claims", "paragraphs" and "features" lists, each
        ordered by the centrality ranking.
    """
    def filter_short_texts(texts, min_tokens=5):
        # Drop fragments too short to carry meaningful content.
        return [text for text in texts if len(text.split()) >= min_tokens]

    # Collect and length-filter each section's texts.
    claims = filter_short_texts(
        [v for k, v in patent_dict.items() if k.startswith("c-en")])
    paragraphs = filter_short_texts(
        [v for k, v in patent_dict.items() if k.startswith("p")])
    # Only the values of the "features" sub-dict are used, so iterate
    # .values() instead of unpacking (and discarding) the keys.
    features = filter_short_texts(
        list(patent_dict.get("features", {}).values()))

    # Re-rank claims and features directly.
    ranked_claims = rank_by_centrality(claims)
    ranked_features = rank_by_centrality(features)

    # Paragraphs get an extra clustering pass before the ranking.
    filtered_paragraphs = cluster_and_rank(paragraphs)
    ranked_paragraphs = rank_by_centrality(filtered_paragraphs)

    return {
        "claims": ranked_claims,
        "paragraphs": ranked_paragraphs,
        "features": ranked_features,
    }
|
| 95 |
+
|
| 96 |
def load_json_file(file_path):
|
| 97 |
"""Load JSON data from a file"""
|
| 98 |
with open(file_path, 'r') as f:
|
|
|
|
| 168 |
filtered_dict = process_single_patent(content_dict)
|
| 169 |
all_text = []
|
| 170 |
# Start with abstract for better context at the beginning
|
| 171 |
+
if "pa01" in content_dict:
|
| 172 |
+
all_text.append(content_dict["pa01"])
|
|
|
|
|
|
|
| 173 |
|
| 174 |
# For claims, paragraphs and features, we take only the top-10 most relevant
|
| 175 |
# Add claims
|
|
|
|
| 183 |
all_text.append(paragraph)
|
| 184 |
|
| 185 |
return " ".join(all_text)
|
| 186 |
+
|
| 187 |
+
elif text_type == "smart2":
|
| 188 |
+
filtered_dict = process_single_patent2(content_dict)
|
| 189 |
+
all_text = []
|
| 190 |
+
# Start with abstract for better context at the beginning
|
| 191 |
+
if "pa01" in content_dict:
|
| 192 |
+
all_text.append(content_dict["pa01"])
|
| 193 |
+
|
| 194 |
+
# For claims, paragraphs and features, we take only the top-10 most relevant
|
| 195 |
+
# Add claims
|
| 196 |
+
for claim in filtered_dict["claims"][:10]:
|
| 197 |
+
all_text.append(claim)
|
| 198 |
+
# Add paragraphs
|
| 199 |
+
for paragraph in filtered_dict["paragraphs"][:10]:
|
| 200 |
+
all_text.append(paragraph)
|
| 201 |
+
# Add features
|
| 202 |
+
for feature in filtered_dict["features"][:10]:
|
| 203 |
+
all_text.append(feature)
|
| 204 |
+
|
| 205 |
+
return " ".join(all_text)
|
| 206 |
|
| 207 |
|
| 208 |
return ""
|