Commit: "Update train.py" — diff view of train.py (changed file).
# Reconstructed from the diff view: post-change train.py content.
# NOTE(review): the diff elides file lines 1-2 (the transformers/datasets
# imports referenced in the hunk header), lines 11-13 (the tokenize_fn
# definition), and lines 37-39 (some TrainingArguments kwargs); those
# portions are not visible here and are not reproduced — confirm against
# the full file.

from peft import get_peft_model, LoraConfig, TaskType
import os

# Load the SST-2 sentiment-classification task of GLUE and keep a small
# 500-example subset so training stays within the 25-minute budget noted
# in the original comments.
dataset = load_dataset("glue", "sst2")
small_train = dataset["train"].select(range(500))

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# tokenize_fn is defined in a portion of the file elided by the diff view.
tokenized_train = small_train.map(tokenize_fn, batched=True)

# Base model: DistilBERT with a fresh 2-label sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

# LoRA adapter configuration. DistilBERT's attention projections are named
# q_lin / v_lin, so target_modules must be set explicitly — PEFT has no
# built-in default mapping for this architecture (this was the fix the
# commit itself introduced).
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"],  # required for DistilBERT
)
model = get_peft_model(model, peft_config)

# Hugging Face token: read from the environment (set HF_TOKEN as a Secret
# in the Space settings).
# FIX: the previous `or "hf_xxx"` fallback substituted an invalid
# placeholder token whenever the secret was missing, turning a missing
# credential into a confusing 401 at push time. Passing None instead lets
# huggingface_hub fall back to any cached login.
hf_token = os.environ.get("HF_TOKEN")

training_args = TrainingArguments(
    output_dir="results",
    per_device_train_batch_size=8,
    # NOTE(review): additional kwargs on file lines 37-39 are elided by
    # the diff view and are not reproduced here — confirm against the
    # full file.
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id="NightPrince/peft-distilbert-sst2",
    hub_token=hf_token,
)
|
| 45 |
|
| 46 |
trainer = Trainer(
|