roshbeed committed on
Commit
bfb323a
·
verified ·
1 Parent(s): 36ec7fa

Upload src/extract_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/extract_data.py +84 -0
src/extract_data.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import json
3
+
4
def extract_queries_and_documents(
    dataset=None,
    output_path='extracted_data.json',
    splits=('train', 'validation', 'test'),
):
    """Extract queries and relevance-labelled passages and save them as JSON.

    For every example in each requested split, the query text is paired with
    its passages and a boolean relevance flag derived from ``is_selected``.
    The result is written to ``output_path`` and summary statistics plus a
    short sample are printed per split.

    Args:
        dataset: Optional mapping from split name to an iterable of examples.
            Each example must provide ``example['query']`` and
            ``example['passages']`` with parallel ``'passage_text'`` and
            ``'is_selected'`` sequences. When ``None`` (the default), MS MARCO
            v1.1 is loaded with ``load_dataset("ms_marco", "v1.1")``.
        output_path: Path of the JSON file to write.
        splits: Names of the splits to process, in order.

    Returns:
        A dict mapping each split name to a list of records of the form
        ``{'query': str, 'passages_with_relevance': [{'passage': str,
        'is_relevant': bool}, ...]}``.

    Raises:
        KeyError: If a requested split or an expected example field is missing.
    """
    # Load the dataset only when the caller did not supply one.
    if dataset is None:
        print("Loading MS MARCO dataset...")
        dataset = load_dataset("ms_marco", "v1.1")

    # De-duplicate while preserving order so no split is processed twice.
    splits = tuple(dict.fromkeys(splits))

    # Dictionary to store our extracted data
    extracted_data = {split: [] for split in splits}

    # Extract data from each split
    for split in splits:
        print(f"\nProcessing {split} split...")
        records = extracted_data[split]

        for count, example in enumerate(dataset[split], start=1):
            passages = example['passages']
            records.append({
                'query': example['query'],
                'passages_with_relevance': [
                    # is_selected is 1 if relevant, 0 if not; store a bool for clarity.
                    {'passage': text, 'is_relevant': bool(selected)}
                    for text, selected in zip(
                        passages['passage_text'], passages['is_selected']
                    )
                ],
            })

            # Print progress every 1000 examples
            if count % 1000 == 0:
                print(f"Processed {count} examples")

    # Save the extracted data
    print("\nSaving extracted data...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(extracted_data, f, indent=2)

    # Print some statistics
    print("\nExtraction complete!")
    for split in splits:
        records = extracted_data[split]
        print(f"\n{split.upper()} split:")
        print(f"Number of queries: {len(records)}")

        # Calculate relevance statistics
        total_passages = sum(
            len(record['passages_with_relevance']) for record in records
        )
        relevant_passages = sum(
            1
            for record in records
            for passage_info in record['passages_with_relevance']
            if passage_info['is_relevant']
        )
        # Guard against empty splits: dividing by zero here would crash the
        # run after all the extraction work has already been done.
        if total_passages:
            percentage = (relevant_passages / total_passages) * 100
        else:
            percentage = 0.0

        print(f"Total number of passages: {total_passages}")
        print(f"Number of relevant passages: {relevant_passages}")
        print(f"Percentage of relevant passages: {percentage:.2f}%")

        # Show a sample
        if records:
            sample = records[0]
            print("\nSample query:", sample['query'])
            print("Number of passages:", len(sample['passages_with_relevance']))
            print("\nSample passages with relevance:")
            # Show first 2 passages
            for i, passage_info in enumerate(sample['passages_with_relevance'][:2]):
                print(f"\nPassage {i+1}:")
                print(f"Relevance: {'Relevant' if passage_info['is_relevant'] else 'Not Relevant'}")
                print(f"Preview: {passage_info['passage'][:200]}...")

    return extracted_data
82
+
83
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    extract_queries_and_documents()