File size: 3,040 Bytes
28939a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from datasets import load_dataset
import random
import json
from tqdm import tqdm
def _sample_negative(pool, positive, max_draws=32):
    """Return a random entry of *pool* that differs from *positive*.

    Draws uniformly from the entries of *pool* that are not equal to
    *positive*. Returns ``None`` when no such entry exists (empty pool, or
    every entry equals *positive*).

    Rejection sampling is used so the common case costs O(1) instead of
    copying the whole pool for every call. If *max_draws* attempts all hit
    *positive* (only plausible for tiny or heavily duplicated pools), the
    function falls back to an explicit filter so it always terminates.
    """
    if not pool:
        return None
    for _ in range(max_draws):
        candidate = random.choice(pool)
        if candidate != positive:
            return candidate
    # Degenerate pool: fall back to the exact (linear-time) filter.
    remaining = [p for p in pool if p != positive]
    return random.choice(remaining) if remaining else None


def generate_triples(output_path="triples_small.json"):
    """Build (query, positive, negative) triples from MS MARCO v1.1.

    For each of the ``train``, ``validation`` and ``test`` splits, every
    passage flagged in ``is_selected`` becomes a positive document for its
    query and is paired with a negative document drawn at random from the
    unique passages of the same split (any passage other than the positive
    itself). The triples for all splits are written as one JSON object to
    *output_path*, and a short summary with one sample per split is printed.

    Args:
        output_path: Destination of the JSON file. Defaults to
            ``"triples_small.json"``, the name the script has always used.

    Returns:
        None. Results are written to *output_path* and reported on stdout.
    """
    print("Loading MS MARCO dataset...")
    dataset = load_dataset("ms_marco", "v1.1")

    splits = ('train', 'validation', 'test')
    triples = {split: [] for split in splits}

    for split in splits:
        print(f"\nProcessing {split} split...")
        split_data = dataset[split]

        # Pool of candidate negatives. dict.fromkeys de-duplicates while
        # keeping first-seen order, so the pool (and therefore the output
        # under a seeded `random`) is identical from run to run; a plain
        # set() would reorder strings differently in every process.
        all_passages = list(dict.fromkeys(
            passage
            for example in split_data
            for passage in example['passages']['passage_text']
        ))
        print(f"Total unique passages for negative sampling: {len(all_passages)}")

        for example in tqdm(split_data, desc=f"Generating triples for {split}"):
            query = example['query']
            passages = example['passages']['passage_text']
            relevance = example['passages']['is_selected']

            for passage, is_relevant in zip(passages, relevance):
                if not is_relevant:
                    continue
                negative_doc = _sample_negative(all_passages, passage)
                if negative_doc is None:
                    # No passage other than the positive exists in this split.
                    continue
                triples[split].append({
                    'query': query,
                    'positive_doc': passage,
                    'negative_doc': negative_doc,
                })

        print(f"Generated {len(triples[split])} triples for {split} split")

    print("\nSaving triples...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(triples, f, indent=2)

    print("\nTriple generation complete!")
    for split in splits:
        print(f"\n{split.upper()} split:")
        print(f"Number of triples: {len(triples[split])}")

        # Show one example so the output can be sanity-checked by eye.
        if triples[split]:
            sample = triples[split][0]
            print("\nSample triple:")
            print(f"Query: {sample['query']}")
            print(f"\nPositive document: {sample['positive_doc'][:200]}...")
            print(f"\nNegative document: {sample['negative_doc'][:200]}...")
if __name__ == "__main__":
    # Script entry point: process every split and write the triples to disk.
    generate_triples()