heerjtdev commited on
Commit
c7a125e
·
verified ·
1 Parent(s): 98e73eb

Rename train.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +264 -0
  2. train.py +0 -244
app.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import sys
5
+ from datetime import datetime
6
+ import shutil
7
+
8
+ # --- CONFIGURATION UPDATED FOR HYBRID MODEL ---
9
+ TRAINING_SCRIPT = "train_hybrid.py"
10
+ MODEL_OUTPUT_DIR = "checkpoints"
11
+ MODEL_FILE_NAME = "layoutlmv3_bilstm_crf_hybrid.pth"
12
+ MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
13
+
14
+ # ----------------------------------------------------------------
15
+
16
+ def retrieve_model():
17
+ """
18
+ Checks for the final model file and prepares it for download.
19
+ Useful for when the training job finishes server-side but the
20
+ client connection has timed out.
21
+ """
22
+ if os.path.exists(MODEL_FILE_PATH):
23
+ file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
24
+
25
+ # Copy to a simple location that Gradio can reliably serve
26
+ import tempfile
27
+ temp_dir = tempfile.gettempdir()
28
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
29
+ temp_model_path = os.path.join(temp_dir, f"hybrid_model_recovered_{timestamp}.pth")
30
+
31
+ try:
32
+ shutil.copy2(MODEL_FILE_PATH, temp_model_path)
33
+ download_path = temp_model_path
34
+
35
+ log_output = (
36
+ f"--- Model Status Check: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
37
+ f"🎉 SUCCESS! The Hybrid LayoutLMv3+BiLSTM+CRF model was found.\n"
38
+ f"📦 Model file: {MODEL_FILE_PATH}\n"
39
+ f"📊 Model size: {file_size:.2f} MB\n"
40
+ f"🔗 Download path prepared: {download_path}\n\n"
41
+ f"⬇️ Click the '📥 Download Model' button below to save your model."
42
+ )
43
+ return log_output, download_path, gr.Button(visible=True)
44
+
45
+ except Exception as e:
46
+ log_output = (
47
+ f"--- Model Status Check FAILED ---\n"
48
+ f"⚠️ Trained model found, but could not prepare for download: {e}\n"
49
+ f"📁 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
50
+ )
51
+ return log_output, None, gr.Button(visible=False)
52
+
53
+ else:
54
+ log_output = (
55
+ f"--- Model Status Check: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
56
+ f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
57
+ f"Training may still be running or it failed. Check back later."
58
+ )
59
+ return log_output, None, gr.Button(visible=False)
60
+
61
+
62
+ def clear_memory(dataset_file: gr.File):
63
+ """
64
+ Deletes the model output directory and the uploaded dataset file.
65
+ """
66
+ log_output = f"--- Memory Clear Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
67
+
68
+ # 1. Clear Model Checkpoints Directory
69
+ if os.path.exists(MODEL_OUTPUT_DIR):
70
+ try:
71
+ shutil.rmtree(MODEL_OUTPUT_DIR)
72
+ log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
73
+ except Exception as e:
74
+ log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
75
+ else:
76
+ log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"
77
+
78
+ # 2. Clear Uploaded Dataset File (Temporary file cleanup)
79
+ if dataset_file is not None:
80
+ input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
81
+ if os.path.exists(input_path):
82
+ try:
83
+ os.remove(input_path)
84
+ log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
85
+ except Exception as e:
86
+ log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"
87
+ else:
88
+ log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"
89
+ else:
90
+ log_output += f"ℹ️ No dataset file currently tracked for deletion.\n"
91
+
92
+ log_output += f"--- Memory Clear Complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
93
+ log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."
94
+
95
+ return log_output, None, gr.Button(visible=False), None
96
+
97
+
98
+ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
99
+ """
100
+ Handles the Gradio submission and executes the training script using subprocess.
101
+ """
102
+ # 1. Setup: Create output directory
103
+ os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
104
+
105
+ # 2. File Handling
106
+ if dataset_file is None:
107
+ yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
108
+ return
109
+
110
+ input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
111
+
112
+ if not os.path.exists(input_path):
113
+ yield f"❌ ERROR: Uploaded file not found at {input_path}.", None, gr.Button(visible=False)
114
+ return
115
+
116
+ progress(0.1, desc="Initializing Hybrid Model Training...")
117
+
118
+ log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
119
+ log_output += f"🤖 Architecture: LayoutLMv3 + BiLSTM + CRF\n"
120
+
121
+ # 3. Construct the subprocess command
122
+ command = [
123
+ sys.executable,
124
+ TRAINING_SCRIPT,
125
+ "--mode", "train",
126
+ "--input", input_path,
127
+ "--batch_size", str(batch_size),
128
+ "--epochs", str(epochs),
129
+ "--lr", str(lr),
130
+ "--max_len", str(max_len)
131
+ ]
132
+
133
+ log_output += f"Executing command: {' '.join(command)}\n\n"
134
+ yield log_output, None, gr.Button(visible=False)
135
+
136
+ try:
137
+ # 4. Run the training script
138
+ process = subprocess.Popen(
139
+ command,
140
+ stdout=subprocess.PIPE,
141
+ stderr=subprocess.STDOUT,
142
+ text=True,
143
+ bufsize=1
144
+ )
145
+
146
+ # Stream logs
147
+ for line in iter(process.stdout.readline, ""):
148
+ log_output += line
149
+ print(line, end='')
150
+ yield log_output, None, gr.Button(visible=False)
151
+
152
+ process.stdout.close()
153
+ return_code = process.wait()
154
+
155
+ # 5. Check completion
156
+ if return_code == 0:
157
+ log_output += "\n" + "=" * 60 + "\n"
158
+ log_output += "✅ HYBRID TRAINING COMPLETE!\n"
159
+ log_output += "=" * 60 + "\n"
160
+
161
+ if os.path.exists(MODEL_FILE_PATH):
162
+ file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)
163
+ log_output += f"\n📦 Model file found: {MODEL_FILE_PATH} ({file_size:.2f} MB)"
164
+
165
+ # Copy for download
166
+ import tempfile
167
+ temp_dir = tempfile.gettempdir()
168
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
169
+ temp_model_path = os.path.join(temp_dir, f"hybrid_model_{timestamp}.pth")
170
+
171
+ try:
172
+ shutil.copy2(MODEL_FILE_PATH, temp_model_path)
173
+ download_path = temp_model_path
174
+ except Exception as e:
175
+ log_output += f"\n⚠️ Copy failed: {e}, using original path"
176
+ download_path = MODEL_FILE_PATH
177
+
178
+ log_output += f"\n\n⬇️ Click the '📥 Download Model' button below."
179
+ yield log_output, download_path, gr.Button(visible=True)
180
+ return
181
+ else:
182
+ log_output += f"\n❌ Error: Training finished but {MODEL_FILE_PATH} was not found."
183
+ yield log_output, None, gr.Button(visible=False)
184
+ return
185
+ else:
186
+ log_output += f"\n❌ TRAINING FAILED with return code {return_code}\n"
187
+ yield log_output, None, gr.Button(visible=False)
188
+ return
189
+
190
+ except FileNotFoundError:
191
+ yield log_output + f"\n❌ ERROR: '{TRAINING_SCRIPT}' not found.", None, gr.Button(visible=False)
192
+ except Exception as e:
193
+ yield log_output + f"\n❌ Unexpected Error: {e}", None, gr.Button(visible=False)
194
+
195
+
196
+ # --- Gradio Interface Setup ---
197
+ with gr.Blocks(title="Hybrid LayoutLM Training", theme=gr.themes.Soft()) as demo:
198
+ gr.Markdown("# 🧬 Hybrid LayoutLMv3 + BiLSTM + CRF Training")
199
+ gr.Markdown(
200
+ """
201
+ **Architecture:** This app trains a state-of-the-art stack:
202
+ 1. **LayoutLMv3** (Visual & Textual Embeddings)
203
+ 2. **Bi-LSTM** (Sequence Context Modeling)
204
+ 3. **CRF** (Label Consistency Enforcement)
205
+
206
+ **Instructions:** Upload your Label Studio JSON, set parameters, and train.
207
+ **Note:** This model is slower to train than standard LayoutLM but typically achieves higher accuracy on complex layouts.
208
+ """
209
+ )
210
+
211
+ with gr.Row():
212
+ with gr.Column(scale=1):
213
+ gr.Markdown("### 📁 Dataset")
214
+ file_input = gr.File(label="Upload Label Studio JSON", file_types=[".json"])
215
+
216
+ gr.Markdown("### ⚙️ Hyperparameters")
217
+ batch_size_input = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
218
+ epochs_input = gr.Slider(1, 10, value=5, step=1, label="Epochs")
219
+ lr_input = gr.Number(value=2e-5, label="Learning Rate (Backbone)", info="LSTM/CRF head uses 1e-4")
220
+ max_len_input = gr.Slider(128, 512, value=512, step=128, label="Max Seq Len")
221
+
222
+ train_button = gr.Button("🔥 Start Hybrid Training", variant="primary", size="lg")
223
+ check_button = gr.Button("🔍 Check Status / Recover Model", variant="secondary")
224
+ clear_button = gr.Button("🧹 Clear Files", variant="stop")
225
+
226
+ with gr.Column(scale=2):
227
+ log_output = gr.Textbox(
228
+ label="Training Logs", lines=25, autoscroll=True, show_copy_button=True,
229
+ placeholder="Logs will appear here..."
230
+ )
231
+
232
+ download_btn = gr.Button("📥 Download Hybrid Model", variant="primary", size="lg", visible=False)
233
+
234
+ # State and hidden download component
235
+ model_path_state = gr.State(value=None)
236
+ model_download = gr.File(label="Download", interactive=False, visible=True)
237
+
238
+ # Actions
239
+ train_button.click(
240
+ fn=train_model,
241
+ inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
242
+ outputs=[log_output, model_path_state, download_btn]
243
+ )
244
+
245
+ check_button.click(
246
+ fn=retrieve_model,
247
+ inputs=[],
248
+ outputs=[log_output, model_path_state, download_btn]
249
+ )
250
+
251
+ clear_button.click(
252
+ fn=clear_memory,
253
+ inputs=[file_input],
254
+ outputs=[log_output, model_path_state, download_btn, model_download]
255
+ )
256
+
257
+ download_btn.click(
258
+ fn=lambda path: path,
259
+ inputs=[model_path_state],
260
+ outputs=[model_download]
261
+ )
262
+
263
+ if __name__ == "__main__":
264
+ demo.launch()
train.py DELETED
@@ -1,244 +0,0 @@
1
- import json
2
- import argparse
3
- import os
4
- import random
5
- import torch
6
- import torch.nn as nn
7
- from torch.utils.data import Dataset, DataLoader, random_split
8
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
- from TorchCRF import CRF
10
- from torch.optim import AdamW
11
- from tqdm import tqdm
12
- from sklearn.metrics import precision_recall_fscore_support
13
-
14
- # --- Configuration ---
15
- MAX_BBOX_DIMENSION = 1000
16
- MAX_SHIFT = 30
17
- AUGMENTATION_FACTOR = 1
18
- BASE_MODEL_ID = "heerjtdev/MLP_LayoutLM"
19
-
20
- # -------------------------
21
- # Step 1: Preprocessing
22
- # -------------------------
23
- def preprocess_labelstudio(input_path, output_path):
24
- with open(input_path, "r", encoding="utf-8") as f:
25
- data = json.load(f)
26
-
27
- processed = []
28
- print(f"🔄 Starting preprocessing of {len(data)} documents...")
29
-
30
- for item in data:
31
- words = item["data"]["original_words"]
32
- bboxes = item["data"]["original_bboxes"]
33
- labels = ["O"] * len(words)
34
-
35
- clamped_bboxes = []
36
- for bbox in bboxes:
37
- x_min, y_min, x_max, y_max = bbox
38
- new_x_min = max(0, min(x_min, 1000))
39
- new_y_min = max(0, min(y_min, 1000))
40
- new_x_max = max(0, min(x_max, 1000))
41
- new_y_max = max(0, min(y_max, 1000))
42
- if new_x_min > new_x_max: new_x_min = new_x_max
43
- if new_y_min > new_y_max: new_y_min = new_y_max
44
- clamped_bboxes.append([new_x_min, new_y_min, new_x_max, new_y_max])
45
-
46
- if "annotations" in item:
47
- for ann in item["annotations"]:
48
- for res in ann["result"]:
49
- if "value" in res and "labels" in res["value"]:
50
- text = res["value"]["text"]
51
- tag = res["value"]["labels"][0]
52
- text_tokens = text.split()
53
- for i in range(len(words) - len(text_tokens) + 1):
54
- if words[i:i + len(text_tokens)] == text_tokens:
55
- labels[i] = f"B-{tag}"
56
- for j in range(1, len(text_tokens)):
57
- labels[i + j] = f"I-{tag}"
58
- break
59
-
60
- processed.append({"tokens": words, "labels": labels, "bboxes": clamped_bboxes})
61
-
62
- with open(output_path, "w", encoding="utf-8") as f:
63
- json.dump(processed, f, indent=2, ensure_ascii=False)
64
- return output_path
65
-
66
- # -------------------------
67
- # Step 1.5: Augmentation
68
- # -------------------------
69
- def translate_bbox(bbox, shift_x, shift_y):
70
- x_min, y_min, x_max, y_max = bbox
71
- new_x_min = max(0, min(x_min + shift_x, 1000))
72
- new_y_min = max(0, min(y_min + shift_y, 1000))
73
- new_x_max = max(0, min(x_max + shift_x, 1000))
74
- new_y_max = max(0, min(y_max + shift_y, 1000))
75
- return [new_x_min, new_y_min, new_x_max, new_y_max]
76
-
77
- def augment_sample(sample):
78
- shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
79
- shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
80
- new_sample = sample.copy()
81
- new_sample["bboxes"] = [translate_bbox(b, shift_x, shift_y) for b in sample["bboxes"]]
82
- return new_sample
83
-
84
- def augment_and_save_dataset(input_json_path, output_json_path):
85
- with open(input_json_path, 'r', encoding="utf-8") as f:
86
- training_data = json.load(f)
87
- augmented_data = []
88
- for original_sample in training_data:
89
- augmented_data.append(original_sample)
90
- for _ in range(AUGMENTATION_FACTOR):
91
- augmented_data.append(augment_sample(original_sample))
92
- with open(output_json_path, 'w', encoding="utf-8") as f:
93
- json.dump(augmented_data, f, indent=2, ensure_ascii=False)
94
- return output_json_path
95
-
96
- # -------------------------
97
- # Step 2: Dataset Class
98
- # -------------------------
99
- class LayoutDataset(Dataset):
100
- def __init__(self, json_path, tokenizer, label2id, max_len=512):
101
- with open(json_path, "r", encoding="utf-8") as f:
102
- self.data = json.load(f)
103
- self.tokenizer = tokenizer
104
- self.label2id = label2id
105
- self.max_len = max_len
106
-
107
- def __len__(self):
108
- return len(self.data)
109
-
110
- def __getitem__(self, idx):
111
- item = self.data[idx]
112
- words, bboxes, labels = item["tokens"], item["bboxes"], item["labels"]
113
- encodings = self.tokenizer(words, boxes=bboxes, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
114
- word_ids = encodings.word_ids(batch_index=0)
115
- label_ids = []
116
- for word_id in word_ids:
117
- if word_id is None:
118
- label_ids.append(self.label2id["O"])
119
- else:
120
- label_ids.append(self.label2id.get(labels[word_id], self.label2id["O"]))
121
- encodings["labels"] = torch.tensor(label_ids)
122
- return {key: val.squeeze(0) for key, val in encodings.items()}
123
-
124
- # -------------------------
125
- # Step 3: Model Architecture (Non-Linear Head)
126
- # -------------------------
127
-
128
- class LayoutLMv3CRF(nn.Module):
129
- def __init__(self, num_labels):
130
- super().__init__()
131
- # Initializing from scratch (Base weights only)
132
- print(f"🔄 Initializing backbone from {BASE_MODEL_ID}...")
133
- self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
134
-
135
- hidden_size = self.layoutlm.config.hidden_size
136
-
137
- # NON-LINEAR MLP HEAD
138
- # Replacing the simple Linear layer with a deeper architecture
139
- self.classifier = nn.Sequential(
140
- nn.Linear(hidden_size, hidden_size),
141
- nn.GELU(), # Non-linear activation
142
- nn.LayerNorm(hidden_size), # Stability for training from scratch
143
- nn.Dropout(0.1),
144
- nn.Linear(hidden_size, num_labels)
145
- )
146
-
147
- self.crf = CRF(num_labels)
148
-
149
- def forward(self, input_ids, bbox, attention_mask, labels=None):
150
- outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
151
- sequence_output = outputs.last_hidden_state
152
-
153
- # Pass through the new non-linear head
154
- emissions = self.classifier(sequence_output)
155
-
156
- if labels is not None:
157
- log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
158
- return -log_likelihood.mean()
159
- else:
160
- return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
161
-
162
- # -------------------------
163
- # Step 4: Training + Evaluation
164
- # -------------------------
165
- def train_one_epoch(model, dataloader, optimizer, device):
166
- model.train()
167
- total_loss = 0
168
- for batch in tqdm(dataloader, desc="Training"):
169
- batch = {k: v.to(device) for k, v in batch.items()}
170
- labels = batch.pop("labels")
171
- optimizer.zero_grad()
172
- loss = model(**batch, labels=labels)
173
- loss.backward()
174
- optimizer.step()
175
- total_loss += loss.item()
176
- return total_loss / len(dataloader)
177
-
178
- def evaluate(model, dataloader, device, id2label):
179
- model.eval()
180
- all_preds, all_labels = [], []
181
- with torch.no_grad():
182
- for batch in tqdm(dataloader, desc="Evaluating"):
183
- batch = {k: v.to(device) for k, v in batch.items()}
184
- labels = batch.pop("labels").cpu().numpy()
185
- preds = model(**batch)
186
- for p, l, mask in zip(preds, labels, batch["attention_mask"].cpu().numpy()):
187
- valid = mask == 1
188
- l_valid = l[valid].tolist()
189
- all_labels.extend(l_valid)
190
- all_preds.extend(p[:len(l_valid)])
191
- precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
192
- return precision, recall, f1
193
-
194
- # -------------------------
195
- # Step 5: Main Execution
196
- # -------------------------
197
- def main(args):
198
- labels = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
199
- label2id = {l: i for i, l in enumerate(labels)}
200
- id2label = {i: l for l, i in label2id.items()}
201
-
202
- TEMP_DIR = "temp_intermediate_files"
203
- os.makedirs(TEMP_DIR, exist_ok=True)
204
-
205
- # 1. Preprocess & Augment
206
- initial_json = os.path.join(TEMP_DIR, "data_bio.json")
207
- preprocess_labelstudio(args.input, initial_json)
208
- augmented_json = os.path.join(TEMP_DIR, "data_aug.json")
209
- final_data_path = augment_and_save_dataset(initial_json, augmented_json)
210
-
211
- # 2. Setup Data
212
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
213
- dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
214
- val_size = int(0.2 * len(dataset))
215
- train_dataset, val_dataset = random_split(dataset, [len(dataset) - val_size, val_size])
216
-
217
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
218
- val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
219
-
220
- # 3. Model
221
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
222
- model = LayoutLMv3CRF(num_labels=len(labels)).to(device)
223
- optimizer = AdamW(model.parameters(), lr=args.lr)
224
-
225
- # 4. Loop
226
- for epoch in range(args.epochs):
227
- loss = train_one_epoch(model, train_loader, optimizer, device)
228
- p, r, f1 = evaluate(model, val_loader, device, id2label)
229
- print(f"Epoch {epoch+1} | Loss: {loss:.4f} | F1: {f1:.3f}")
230
-
231
- ckpt_path = "checkpoints/layoutlmv3_nonlinear_scratch.pth"
232
- os.makedirs("checkpoints", exist_ok=True)
233
- torch.save(model.state_dict(), ckpt_path)
234
-
235
- if __name__ == "__main__":
236
- parser = argparse.ArgumentParser()
237
- parser.add_argument("--mode", type=str, default="train")
238
- parser.add_argument("--input", type=str, required=True)
239
- parser.add_argument("--batch_size", type=int, default=4)
240
- parser.add_argument("--epochs", type=int, default=10) # Increased for scratch training
241
- parser.add_argument("--lr", type=float, default=2e-5)
242
- parser.add_argument("--max_len", type=int, default=512)
243
- args = parser.parse_args()
244
- main(args)