heerjtdev commited on
Commit
703b939
·
verified ·
1 Parent(s): 3a2aaad

Upload Data_augmentation.py

Browse files
Files changed (1) hide show
  1. Data_augmentation.py +105 -0
Data_augmentation.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import os
4
+
5
# --- Configuration ---
# The name of the file to load and save to. NOTE: the combined
# (original + augmented) data is written back to this SAME file,
# so re-running the script compounds the augmentation.
INPUT_FILE = "unified_training_data_bluuhhhhh.json"
# The maximum allowed deviation for the shift in the x and y directions.
# The shift is drawn uniformly from [-MAX_SHIFT, +MAX_SHIFT] per copy.
# (An earlier comment said "+/- 5"; the constant is 10 — the code wins.)
MAX_SHIFT = 10
# The coordinate boundary limits (assumes coordinates are scaled 0-1000).
MAX_COORD = 1000
MIN_COORD = 0
# Number of augmented copies to create (1 means the original dataset size is doubled)
NUM_AUGMENTATION_COPIES = 1
16
+
17
+
18
def clip_coord(coord):
    """Clamp a coordinate into the inclusive [MIN_COORD, MAX_COORD] range."""
    if coord < MIN_COORD:
        return MIN_COORD
    if coord > MAX_COORD:
        return MAX_COORD
    return coord
21
+
22
+
23
def augment_data(data, shift_x, shift_y):
    """
    Translate every bounding box in *data* by (shift_x, shift_y) and
    return the shifted tokens as a new list.

    The same offset is applied to all tokens in this copy, which
    preserves the relative layout structure of the document. Each
    shifted coordinate is clipped via clip_coord().

    Note: tokens are shallow-copied — only the 'bbox' entry is replaced
    with a new list, so other nested mutable values (if any) are shared
    with the originals.
    """
    shifted_tokens = []

    for token in data:
        # Copy the token dict so the caller's data is not mutated.
        replacement = dict(token)

        # Bounding box layout: [x_min, y_min, x_max, y_max].
        box = replacement['bbox']
        replacement['bbox'] = [
            clip_coord(box[0] + shift_x),  # x_min
            clip_coord(box[1] + shift_y),  # y_min
            clip_coord(box[2] + shift_x),  # x_max
            clip_coord(box[3] + shift_y),  # y_max
        ]

        shifted_tokens.append(replacement)

    return shifted_tokens
52
+
53
+
54
def process_dataset():
    """Load the original data, perform augmentation, and save the combined data.

    Reads INPUT_FILE (expected to be a JSON list of token objects, each
    with a 'bbox' entry), appends NUM_AUGMENTATION_COPIES uniformly
    shifted copies, and writes the combined list back to INPUT_FILE.

    Prints an error and returns early if the file is missing, is not
    valid JSON, or does not contain a JSON list.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: Input file '{INPUT_FILE}' not found.")
        print("Please ensure your uploaded JSON file is available and named correctly.")
        return

    print(f"Loading data from {INPUT_FILE}...")
    try:
        # Explicit encoding: JSON files are UTF-8 by spec; relying on the
        # platform default locale encoding can corrupt non-ASCII text.
        with open(INPUT_FILE, 'r', encoding='utf-8') as f:
            original_data = json.load(f)
    except json.JSONDecodeError:
        print(f"Error: Failed to decode JSON from '{INPUT_FILE}'. Check file format.")
        return
    except Exception as e:
        # Top-level script boundary: report and bail rather than traceback.
        print(f"An error occurred while reading the file: {e}")
        return

    # Guard against a well-formed JSON file with the wrong shape (e.g. a
    # dict); .copy()/.extend() below assume a list of token objects.
    if not isinstance(original_data, list):
        print(f"Error: Expected '{INPUT_FILE}' to contain a JSON list of tokens.")
        return

    print(f"Original dataset size: {len(original_data)} tokens.")

    all_combined_data = original_data.copy()

    for i in range(NUM_AUGMENTATION_COPIES):
        # 1. Choose a uniform shift for the entire dataset copy.
        #    This is the core spatial jittering logic.
        shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
        shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)

        print(f"\nCreating augmented copy #{i + 1} with uniform shift (X: {shift_x}, Y: {shift_y})...")

        # 2. Perform the augmentation.
        augmented_copy = augment_data(original_data, shift_x, shift_y)

        # 3. Append the augmented data to the combined list.
        all_combined_data.extend(augmented_copy)

    print(f"\nAugmentation complete. Total dataset size: {len(all_combined_data)} tokens.")

    # 4. Save the combined (original + augmented) data back to the file.
    print(f"Saving combined data back to {INPUT_FILE}...")
    try:
        with open(INPUT_FILE, 'w', encoding='utf-8') as f:
            # Use indent for readability.
            json.dump(all_combined_data, f, indent=2)
        print("Successfully updated the dataset with augmented data.")
    except Exception as e:
        print(f"An error occurred while writing the file: {e}")
102
+
103
+
104
# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    process_dataset()