roshbeed committed
Commit 28939a3 · verified · 1 Parent(s): 9c79fa8

Upload src/generate_triples.py with huggingface_hub
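
For reference, a minimal sketch of how such an upload is typically done with huggingface_hub (the repo_id below is hypothetical; the actual target repo is not shown on this page):

from huggingface_hub import HfApi

api = HfApi()  # assumes prior authentication, e.g. via `huggingface-cli login`
api.upload_file(
    path_or_fileobj="src/generate_triples.py",
    path_in_repo="src/generate_triples.py",
    repo_id="roshbeed/msmarco-triples",  # hypothetical repo id
    repo_type="dataset",                 # assumption: a dataset repo
    commit_message="Upload src/generate_triples.py with huggingface_hub",
)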

Files changed (1)
  1. src/generate_triples.py +77 -0
src/generate_triples.py ADDED
from datasets import load_dataset
import random
import json
from tqdm import tqdm

def generate_triples():
    # Load the full MS MARCO v1.1 dataset
    print("Loading MS MARCO dataset...")
    dataset = load_dataset("ms_marco", "v1.1")

    # Dictionary to store the triples for each split
    triples = {
        'train': [],
        'validation': [],
        'test': []
    }

    # Process each split in full
    for split in ['train', 'validation', 'test']:
        print(f"\nProcessing {split} split...")
        split_data = dataset[split]

        # First, collect every passage in the split as the pool for negative sampling
        all_passages = []
        for example in split_data:
            all_passages.extend(example['passages']['passage_text'])
        all_passages = list(set(all_passages))  # Remove duplicates
        print(f"Total unique passages for negative sampling: {len(all_passages)}")

        # Generate (query, positive_doc, negative_doc) triples
        for example in tqdm(split_data, desc=f"Generating triples for {split}"):
            query = example['query']
            passages = example['passages']['passage_text']
            relevance = example['passages']['is_selected']

            # Each selected (relevant) passage yields one triple
            for passage, is_relevant in zip(passages, relevance):
                if not is_relevant:
                    continue
                if len(all_passages) < 2:
                    continue  # No distinct negative available

                # Rejection-sample a negative: redraw until it differs from the
                # positive (same uniform distribution as filtering the pool, but
                # avoids building an O(N) filtered list for every positive passage)
                negative_doc = random.choice(all_passages)
                while negative_doc == passage:
                    negative_doc = random.choice(all_passages)

                triples[split].append({
                    'query': query,
                    'positive_doc': passage,
                    'negative_doc': negative_doc
                })

        print(f"Generated {len(triples[split])} triples for {split} split")

    # Save the triples
    print("\nSaving triples...")
    with open('triples_small.json', 'w') as f:
        json.dump(triples, f, indent=2)

    # Print statistics and one sample triple per split
    print("\nTriple generation complete!")
    for split in ['train', 'validation', 'test']:
        print(f"\n{split.upper()} split:")
        print(f"Number of triples: {len(triples[split])}")

        if triples[split]:
            sample = triples[split][0]
            print("\nSample triple:")
            print(f"Query: {sample['query']}")
            print(f"\nPositive document: {sample['positive_doc'][:200]}...")
            print(f"\nNegative document: {sample['negative_doc'][:200]}...")

if __name__ == "__main__":
    generate_triples()
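
To sanity-check the output, the saved JSON can be loaded back directly (a minimal sketch, assuming the script has been run and triples_small.json is in the working directory):

import json

with open('triples_small.json') as f:
    triples = json.load(f)

for split in ('train', 'validation', 'test'):
    print(split, len(triples[split]))
print(triples['train'][0]['query'])  # first training query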