Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -71,7 +71,8 @@ def load_document_context(task_id):
|
|
| 71 |
def fine_tune_cuad_model():
|
| 72 |
"""
|
| 73 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
| 74 |
-
|
|
|
|
| 75 |
"""
|
| 76 |
from datasets import load_dataset
|
| 77 |
import numpy as np
|
|
@@ -81,9 +82,11 @@ def fine_tune_cuad_model():
|
|
| 81 |
dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
|
| 82 |
|
| 83 |
if "train" in dataset:
|
| 84 |
-
|
|
|
|
| 85 |
if "validation" in dataset:
|
| 86 |
-
|
|
|
|
| 87 |
else:
|
| 88 |
split = train_dataset.train_test_split(test_size=0.2)
|
| 89 |
train_dataset = split["train"]
|
|
@@ -148,17 +151,18 @@ def fine_tune_cuad_model():
|
|
| 148 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
| 149 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
| 150 |
|
|
|
|
| 151 |
training_args = TrainingArguments(
|
| 152 |
output_dir="./fine_tuned_legal_qa",
|
| 153 |
evaluation_strategy="steps",
|
| 154 |
-
eval_steps=
|
| 155 |
learning_rate=2e-5,
|
| 156 |
-
per_device_train_batch_size=
|
| 157 |
-
per_device_eval_batch_size=
|
| 158 |
-
num_train_epochs=1,
|
| 159 |
weight_decay=0.01,
|
| 160 |
-
logging_steps=
|
| 161 |
-
save_steps=
|
| 162 |
load_best_model_at_end=True,
|
| 163 |
report_to=[] # Disable wandb logging
|
| 164 |
)
|
|
@@ -737,3 +741,4 @@ if __name__ == "__main__":
|
|
| 737 |
else:
|
| 738 |
print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
|
| 739 |
run()
|
|
|
|
|
|
| 71 |
def fine_tune_cuad_model():
|
| 72 |
"""
|
| 73 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
| 74 |
+
For testing, we use only 50 training examples (and 10 for validation)
|
| 75 |
+
and set training arguments for very fast, minimal training.
|
| 76 |
"""
|
| 77 |
from datasets import load_dataset
|
| 78 |
import numpy as np
|
|
|
|
| 82 |
dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
|
| 83 |
|
| 84 |
if "train" in dataset:
|
| 85 |
+
# Use only 50 examples for training
|
| 86 |
+
train_dataset = dataset["train"].select(range(50))
|
| 87 |
if "validation" in dataset:
|
| 88 |
+
# Use 10 examples for validation
|
| 89 |
+
val_dataset = dataset["validation"].select(range(10))
|
| 90 |
else:
|
| 91 |
split = train_dataset.train_test_split(test_size=0.2)
|
| 92 |
train_dataset = split["train"]
|
|
|
|
| 151 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
| 152 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
| 153 |
|
| 154 |
+
# Adjust training arguments for fast testing
|
| 155 |
training_args = TrainingArguments(
|
| 156 |
output_dir="./fine_tuned_legal_qa",
|
| 157 |
evaluation_strategy="steps",
|
| 158 |
+
eval_steps=10,
|
| 159 |
learning_rate=2e-5,
|
| 160 |
+
per_device_train_batch_size=4,
|
| 161 |
+
per_device_eval_batch_size=4,
|
| 162 |
+
num_train_epochs=0.1, # Very short training for testing purposes
|
| 163 |
weight_decay=0.01,
|
| 164 |
+
logging_steps=5,
|
| 165 |
+
save_steps=10,
|
| 166 |
load_best_model_at_end=True,
|
| 167 |
report_to=[] # Disable wandb logging
|
| 168 |
)
|
|
|
|
| 741 |
else:
|
| 742 |
print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
|
| 743 |
run()
|
| 744 |
+
|