Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
|
@@ -223,6 +223,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
|
|
| 223 |
return [idx for idx, _ in indexed_scores]
|
| 224 |
|
| 225 |
def main():
|
|
|
|
| 226 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
| 227 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
| 228 |
help='Path to pre-ranking JSON file')
|
|
@@ -248,7 +249,7 @@ def main():
|
|
| 248 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
| 249 |
help='Device to use (cuda/cpu)')
|
| 250 |
parser.add_argument('--base_dir', type=str,
|
| 251 |
-
default='datasets',
|
| 252 |
help='Base directory for data files')
|
| 253 |
|
| 254 |
args = parser.parse_args()
|
|
|
|
| 223 |
return [idx for idx, _ in indexed_scores]
|
| 224 |
|
| 225 |
def main():
|
| 226 |
+
base_directory = os.getcwd()
|
| 227 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
| 228 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
| 229 |
help='Path to pre-ranking JSON file')
|
|
|
|
| 249 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
| 250 |
help='Device to use (cuda/cpu)')
|
| 251 |
parser.add_argument('--base_dir', type=str,
|
| 252 |
+
default=f'{base_directory}/datasets',
|
| 253 |
help='Base directory for data files')
|
| 254 |
|
| 255 |
args = parser.parse_args()
|