heerjtdev commited on
Commit
27b7a20
·
verified ·
1 Parent(s): 703b939

Rename Data_augmentation.py to train.py

Browse files
Files changed (2) hide show
  1. Data_augmentation.py +0 -105
  2. train.py +244 -0
Data_augmentation.py DELETED
@@ -1,105 +0,0 @@
1
- import json
2
- import random
3
- import os
4
-
5
- # --- Configuration ---
6
- # The name of the file to load and save to.
7
- INPUT_FILE = "unified_training_data_bluuhhhhh.json"
8
- # The maximum allowed deviation for the shift in the x and y directions.
9
- # A range of +/- 5 is used to keep the change subtle but effective.
10
- MAX_SHIFT = 10
11
- # The coordinate boundary limit (assuming coordinates are scaled 0-1000)
12
- MAX_COORD = 1000
13
- MIN_COORD = 0
14
- # Number of augmented copies to create (1 means the original dataset size is doubled)
15
- NUM_AUGMENTATION_COPIES = 1
16
-
17
-
18
- def clip_coord(coord):
19
- """Ensures a coordinate stays within the 0 to MAX_COORD boundary."""
20
- return max(MIN_COORD, min(MAX_COORD, coord))
21
-
22
-
23
- def augment_data(data, shift_x, shift_y):
24
- """
25
- Applies a uniform translation shift to all bounding boxes in the dataset
26
- and returns the new augmented list of tokens.
27
-
28
- The shift_x and shift_y are the same for all tokens in this copy,
29
- preserving the crucial relative layout structure.
30
- """
31
- augmented_data = []
32
-
33
- for item in data:
34
- # Create a deep copy of the item to avoid modifying the original data in place
35
- new_item = item.copy()
36
-
37
- # Bounding box coordinates: [x_min, y_min, x_max, y_max]
38
- bbox = new_item['bbox']
39
-
40
- # Apply the uniform shift and clip the coordinates
41
- new_bbox = [
42
- clip_coord(bbox[0] + shift_x), # x_min
43
- clip_coord(bbox[1] + shift_y), # y_min
44
- clip_coord(bbox[2] + shift_x), # x_max
45
- clip_coord(bbox[3] + shift_y) # y_max
46
- ]
47
-
48
- new_item['bbox'] = new_bbox
49
- augmented_data.append(new_item)
50
-
51
- return augmented_data
52
-
53
-
54
- def process_dataset():
55
- """Loads the original data, performs augmentation, and saves the combined data."""
56
- if not os.path.exists(INPUT_FILE):
57
- print(f"Error: Input file '{INPUT_FILE}' not found.")
58
- print("Please ensure your uploaded JSON file is available and named correctly.")
59
- return
60
-
61
- print(f"Loading data from {INPUT_FILE}...")
62
- try:
63
- with open(INPUT_FILE, 'r') as f:
64
- # Assuming the JSON file is a list of token objects
65
- original_data = json.load(f)
66
- except json.JSONDecodeError:
67
- print(f"Error: Failed to decode JSON from '{INPUT_FILE}'. Check file format.")
68
- return
69
- except Exception as e:
70
- print(f"An error occurred while reading the file: {e}")
71
- return
72
-
73
- print(f"Original dataset size: {len(original_data)} tokens.")
74
-
75
- all_combined_data = original_data.copy()
76
-
77
- for i in range(NUM_AUGMENTATION_COPIES):
78
- # 1. Choose a uniform shift for the entire dataset copy
79
- # This is the core spatial jittering logic.
80
- shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
81
- shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
82
-
83
- print(f"\nCreating augmented copy #{i + 1} with uniform shift (X: {shift_x}, Y: {shift_y})...")
84
-
85
- # 2. Perform the augmentation
86
- augmented_copy = augment_data(original_data, shift_x, shift_y)
87
-
88
- # 3. Append the augmented data to the combined list
89
- all_combined_data.extend(augmented_copy)
90
-
91
- print(f"\nAugmentation complete. Total dataset size: {len(all_combined_data)} tokens.")
92
-
93
- # 4. Save the combined (original + augmented) data back to the file
94
- print(f"Saving combined data back to {INPUT_FILE}...")
95
- try:
96
- with open(INPUT_FILE, 'w') as f:
97
- # Use indent for readability
98
- json.dump(all_combined_data, f, indent=2)
99
- print("Successfully updated the dataset with augmented data.")
100
- except Exception as e:
101
- print(f"An error occurred while writing the file: {e}")
102
-
103
-
104
- if __name__ == "__main__":
105
- process_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ import random
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset, DataLoader, random_split
8
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
+ from TorchCRF import CRF
10
+ from torch.optim import AdamW
11
+ from tqdm import tqdm
12
+ from sklearn.metrics import precision_recall_fscore_support
13
+
14
+ # --- Configuration ---
15
+ MAX_BBOX_DIMENSION = 1000
16
+ MAX_SHIFT = 30
17
+ AUGMENTATION_FACTOR = 1
18
+ BASE_MODEL_ID = "microsoft/layoutlmv3-base"
19
+
20
+ # -------------------------
21
+ # Step 1: Preprocessing
22
+ # -------------------------
23
+ def preprocess_labelstudio(input_path, output_path):
24
+ with open(input_path, "r", encoding="utf-8") as f:
25
+ data = json.load(f)
26
+
27
+ processed = []
28
+ print(f"🔄 Starting preprocessing of {len(data)} documents...")
29
+
30
+ for item in data:
31
+ words = item["data"]["original_words"]
32
+ bboxes = item["data"]["original_bboxes"]
33
+ labels = ["O"] * len(words)
34
+
35
+ clamped_bboxes = []
36
+ for bbox in bboxes:
37
+ x_min, y_min, x_max, y_max = bbox
38
+ new_x_min = max(0, min(x_min, 1000))
39
+ new_y_min = max(0, min(y_min, 1000))
40
+ new_x_max = max(0, min(x_max, 1000))
41
+ new_y_max = max(0, min(y_max, 1000))
42
+ if new_x_min > new_x_max: new_x_min = new_x_max
43
+ if new_y_min > new_y_max: new_y_min = new_y_max
44
+ clamped_bboxes.append([new_x_min, new_y_min, new_x_max, new_y_max])
45
+
46
+ if "annotations" in item:
47
+ for ann in item["annotations"]:
48
+ for res in ann["result"]:
49
+ if "value" in res and "labels" in res["value"]:
50
+ text = res["value"]["text"]
51
+ tag = res["value"]["labels"][0]
52
+ text_tokens = text.split()
53
+ for i in range(len(words) - len(text_tokens) + 1):
54
+ if words[i:i + len(text_tokens)] == text_tokens:
55
+ labels[i] = f"B-{tag}"
56
+ for j in range(1, len(text_tokens)):
57
+ labels[i + j] = f"I-{tag}"
58
+ break
59
+
60
+ processed.append({"tokens": words, "labels": labels, "bboxes": clamped_bboxes})
61
+
62
+ with open(output_path, "w", encoding="utf-8") as f:
63
+ json.dump(processed, f, indent=2, ensure_ascii=False)
64
+ return output_path
65
+
66
+ # -------------------------
67
+ # Step 1.5: Augmentation
68
+ # -------------------------
69
+ def translate_bbox(bbox, shift_x, shift_y):
70
+ x_min, y_min, x_max, y_max = bbox
71
+ new_x_min = max(0, min(x_min + shift_x, 1000))
72
+ new_y_min = max(0, min(y_min + shift_y, 1000))
73
+ new_x_max = max(0, min(x_max + shift_x, 1000))
74
+ new_y_max = max(0, min(y_max + shift_y, 1000))
75
+ return [new_x_min, new_y_min, new_x_max, new_y_max]
76
+
77
+ def augment_sample(sample):
78
+ shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
79
+ shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
80
+ new_sample = sample.copy()
81
+ new_sample["bboxes"] = [translate_bbox(b, shift_x, shift_y) for b in sample["bboxes"]]
82
+ return new_sample
83
+
84
+ def augment_and_save_dataset(input_json_path, output_json_path):
85
+ with open(input_json_path, 'r', encoding="utf-8") as f:
86
+ training_data = json.load(f)
87
+ augmented_data = []
88
+ for original_sample in training_data:
89
+ augmented_data.append(original_sample)
90
+ for _ in range(AUGMENTATION_FACTOR):
91
+ augmented_data.append(augment_sample(original_sample))
92
+ with open(output_json_path, 'w', encoding="utf-8") as f:
93
+ json.dump(augmented_data, f, indent=2, ensure_ascii=False)
94
+ return output_json_path
95
+
96
+ # -------------------------
97
+ # Step 2: Dataset Class
98
+ # -------------------------
99
+ class LayoutDataset(Dataset):
100
+ def __init__(self, json_path, tokenizer, label2id, max_len=512):
101
+ with open(json_path, "r", encoding="utf-8") as f:
102
+ self.data = json.load(f)
103
+ self.tokenizer = tokenizer
104
+ self.label2id = label2id
105
+ self.max_len = max_len
106
+
107
+ def __len__(self):
108
+ return len(self.data)
109
+
110
+ def __getitem__(self, idx):
111
+ item = self.data[idx]
112
+ words, bboxes, labels = item["tokens"], item["bboxes"], item["labels"]
113
+ encodings = self.tokenizer(words, boxes=bboxes, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
114
+ word_ids = encodings.word_ids(batch_index=0)
115
+ label_ids = []
116
+ for word_id in word_ids:
117
+ if word_id is None:
118
+ label_ids.append(self.label2id["O"])
119
+ else:
120
+ label_ids.append(self.label2id.get(labels[word_id], self.label2id["O"]))
121
+ encodings["labels"] = torch.tensor(label_ids)
122
+ return {key: val.squeeze(0) for key, val in encodings.items()}
123
+
124
+ # -------------------------
125
+ # Step 3: Model Architecture (Non-Linear Head)
126
+ # -------------------------
127
+
128
+ class LayoutLMv3CRF(nn.Module):
129
+ def __init__(self, num_labels):
130
+ super().__init__()
131
+ # Initializing from scratch (Base weights only)
132
+ print(f"🔄 Initializing backbone from {BASE_MODEL_ID}...")
133
+ self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
134
+
135
+ hidden_size = self.layoutlm.config.hidden_size
136
+
137
+ # NON-LINEAR MLP HEAD
138
+ # Replacing the simple Linear layer with a deeper architecture
139
+ self.classifier = nn.Sequential(
140
+ nn.Linear(hidden_size, hidden_size),
141
+ nn.GELU(), # Non-linear activation
142
+ nn.LayerNorm(hidden_size), # Stability for training from scratch
143
+ nn.Dropout(0.1),
144
+ nn.Linear(hidden_size, num_labels)
145
+ )
146
+
147
+ self.crf = CRF(num_labels)
148
+
149
+ def forward(self, input_ids, bbox, attention_mask, labels=None):
150
+ outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
151
+ sequence_output = outputs.last_hidden_state
152
+
153
+ # Pass through the new non-linear head
154
+ emissions = self.classifier(sequence_output)
155
+
156
+ if labels is not None:
157
+ log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
158
+ return -log_likelihood.mean()
159
+ else:
160
+ return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
161
+
162
+ # -------------------------
163
+ # Step 4: Training + Evaluation
164
+ # -------------------------
165
+ def train_one_epoch(model, dataloader, optimizer, device):
166
+ model.train()
167
+ total_loss = 0
168
+ for batch in tqdm(dataloader, desc="Training"):
169
+ batch = {k: v.to(device) for k, v in batch.items()}
170
+ labels = batch.pop("labels")
171
+ optimizer.zero_grad()
172
+ loss = model(**batch, labels=labels)
173
+ loss.backward()
174
+ optimizer.step()
175
+ total_loss += loss.item()
176
+ return total_loss / len(dataloader)
177
+
178
+ def evaluate(model, dataloader, device, id2label):
179
+ model.eval()
180
+ all_preds, all_labels = [], []
181
+ with torch.no_grad():
182
+ for batch in tqdm(dataloader, desc="Evaluating"):
183
+ batch = {k: v.to(device) for k, v in batch.items()}
184
+ labels = batch.pop("labels").cpu().numpy()
185
+ preds = model(**batch)
186
+ for p, l, mask in zip(preds, labels, batch["attention_mask"].cpu().numpy()):
187
+ valid = mask == 1
188
+ l_valid = l[valid].tolist()
189
+ all_labels.extend(l_valid)
190
+ all_preds.extend(p[:len(l_valid)])
191
+ precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
192
+ return precision, recall, f1
193
+
194
+ # -------------------------
195
+ # Step 5: Main Execution
196
+ # -------------------------
197
+ def main(args):
198
+ labels = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
199
+ label2id = {l: i for i, l in enumerate(labels)}
200
+ id2label = {i: l for l, i in label2id.items()}
201
+
202
+ TEMP_DIR = "temp_intermediate_files"
203
+ os.makedirs(TEMP_DIR, exist_ok=True)
204
+
205
+ # 1. Preprocess & Augment
206
+ initial_json = os.path.join(TEMP_DIR, "data_bio.json")
207
+ preprocess_labelstudio(args.input, initial_json)
208
+ augmented_json = os.path.join(TEMP_DIR, "data_aug.json")
209
+ final_data_path = augment_and_save_dataset(initial_json, augmented_json)
210
+
211
+ # 2. Setup Data
212
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
213
+ dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
214
+ val_size = int(0.2 * len(dataset))
215
+ train_dataset, val_dataset = random_split(dataset, [len(dataset) - val_size, val_size])
216
+
217
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
218
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
219
+
220
+ # 3. Model
221
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
222
+ model = LayoutLMv3CRF(num_labels=len(labels)).to(device)
223
+ optimizer = AdamW(model.parameters(), lr=args.lr)
224
+
225
+ # 4. Loop
226
+ for epoch in range(args.epochs):
227
+ loss = train_one_epoch(model, train_loader, optimizer, device)
228
+ p, r, f1 = evaluate(model, val_loader, device, id2label)
229
+ print(f"Epoch {epoch+1} | Loss: {loss:.4f} | F1: {f1:.3f}")
230
+
231
+ ckpt_path = "checkpoints/layoutlmv3_nonlinear_scratch.pth"
232
+ os.makedirs("checkpoints", exist_ok=True)
233
+ torch.save(model.state_dict(), ckpt_path)
234
+
235
+ if __name__ == "__main__":
236
+ parser = argparse.ArgumentParser()
237
+ parser.add_argument("--mode", type=str, default="train")
238
+ parser.add_argument("--input", type=str, required=True)
239
+ parser.add_argument("--batch_size", type=int, default=4)
240
+ parser.add_argument("--epochs", type=int, default=10) # Increased for scratch training
241
+ parser.add_argument("--lr", type=float, default=2e-5)
242
+ parser.add_argument("--max_len", type=int, default=512)
243
+ args = parser.parse_args()
244
+ main(args)