|
|
from datasets import load_dataset |
|
|
import random |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
|
|
|
def _unique_passages(split_data):
    """Return every passage text in ``split_data`` exactly once, in first-seen order."""
    # dict.fromkeys de-duplicates while keeping insertion order. set() would
    # order strings by their (randomized) hashes, making sampling
    # irreproducible across interpreter runs even with a fixed seed.
    return list(dict.fromkeys(
        text
        for example in split_data
        for text in example['passages']['passage_text']
    ))


def _sample_negative(pool, exclude, rng, max_tries=64):
    """Draw one passage from ``pool`` uniformly among those not in ``exclude``.

    Args:
        pool: Sequence of unique candidate passages.
        exclude: Container of passages that must not be returned.
        rng: Object exposing ``choice`` (the ``random`` module or a
            ``random.Random`` instance).
        max_tries: Rejection-sampling attempts before exact filtering.

    Returns:
        The sampled passage, or None when every passage in ``pool`` is
        excluded (or ``pool`` is empty).
    """
    if not pool:
        return None
    # Rejection sampling: the excluded set is a handful of passages from a
    # large pool, so the first draw almost always succeeds. This avoids
    # copying the whole pool once per positive passage.
    for _ in range(max_tries):
        candidate = rng.choice(pool)
        if candidate not in exclude:
            return candidate
    # Degenerate pool (mostly or entirely excluded): filter exactly, which
    # keeps the result uniform and detects the "no negative" case reliably.
    candidates = [p for p in pool if p not in exclude]
    return rng.choice(candidates) if candidates else None


def _print_summary(triples):
    """Print each split's triple count and its first triple, truncated to 200 chars."""
    print("\nTriple generation complete!")
    for split, split_triples in triples.items():
        print(f"\n{split.upper()} split:")
        print(f"Number of triples: {len(split_triples)}")
        if split_triples:
            sample = split_triples[0]
            print("\nSample triple:")
            print(f"Query: {sample['query']}")
            print(f"\nPositive document: {sample['positive_doc'][:200]}...")
            print(f"\nNegative document: {sample['negative_doc'][:200]}...")


def generate_triples(output_path='triples_small.json',
                     splits=('train', 'validation', 'test'),
                     seed=None):
    """Build (query, positive, negative) triples from MS MARCO v1.1 and save them.

    For each split, every passage flagged ``is_selected`` for a query becomes
    a positive document. It is paired with one negative drawn uniformly at
    random from the split's unique passages, excluding every passage selected
    for that same query. All triples are written as JSON to ``output_path``,
    and a per-split summary is printed.

    Args:
        output_path: Destination JSON file (default ``'triples_small.json'``).
        splits: Dataset splits to process, in order.
        seed: Optional seed for a private ``random.Random``. When None, the
            module-level ``random`` generator is used, so a prior
            ``random.seed(...)`` call still controls sampling.

    Returns:
        dict mapping each split name to a list of dicts with keys
        ``'query'``, ``'positive_doc'`` and ``'negative_doc'``.
    """
    rng = random if seed is None else random.Random(seed)

    print("Loading MS MARCO dataset...")
    dataset = load_dataset("ms_marco", "v1.1")

    triples = {split: [] for split in splits}

    for split in splits:
        print(f"\nProcessing {split} split...")
        split_data = dataset[split]

        # Pool of candidate negatives, built once per split.
        all_passages = _unique_passages(split_data)
        print(f"Total unique passages for negative sampling: {len(all_passages)}")

        for example in tqdm(split_data, desc=f"Generating triples for {split}"):
            query = example['query']
            passages = example['passages']['passage_text']
            relevance = example['passages']['is_selected']

            positives = [p for p, rel in zip(passages, relevance) if rel]
            # A passage selected for this query must never serve as its
            # negative, otherwise the triple contradicts itself.
            exclude = set(positives)

            for passage in positives:
                negative_doc = _sample_negative(all_passages, exclude, rng)
                if negative_doc is None:
                    continue
                triples[split].append({
                    'query': query,
                    'positive_doc': passage,
                    'negative_doc': negative_doc,
                })

        print(f"Generated {len(triples[split])} triples for {split} split")

    print("\nSaving triples...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(triples, f, indent=2)

    _print_summary(triples)
    return triples
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when this file is run as a script, not when imported:
    # build the triples and write them out with the default settings.
    generate_triples()