import os
from datetime import datetime

import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    DataCollatorWithPadding,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

from src.config import HF_ENDPOINT, HF_TOKEN, MODEL_DIR

if HF_ENDPOINT:
    os.environ["HF_ENDPOINT"] = HF_ENDPOINT

# Reproducibility
set_seed(42)
np.random.seed(42)
torch.manual_seed(42)

output_dir = os.path.join(MODEL_DIR, "bias_checkpoints")
os.makedirs(MODEL_DIR, exist_ok=True)

def main():
    dataset = load_dataset("mediabiasgroup/BABE", token=HF_TOKEN)
    dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)

    model_name = "roberta-base"
    tokenizer = RobertaTokenizer.from_pretrained(model_name, token=HF_TOKEN)

    # Optional relabelling helper; currently unused (the map call below stays commented out).
    def collapse_labels(example):
        old = example["label"]
        if old in [0, 1]:
            example["label"] = 0  # Right
        elif old == 2:
            example["label"] = 1  # Center
        else:
            example["label"] = 2  # Left
        return example

    # dataset = dataset.map(collapse_labels)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=128)

    # Pad dynamically per batch instead of padding every example to max_length.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    model = RobertaForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        token=HF_TOKEN,
    )

    # LoRA adapters on the attention query/value projections only.
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query", "value"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        report_to="none",
    )

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=1)
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1_weighted": f1_score(labels, preds, average="weighted"),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        compute_metrics=compute_metrics,
        data_collator=data_collator,
    )
    trainer.train()

    # Merge the LoRA adapters back into the base weights so the saved
    # checkpoint loads as a plain RobertaForSequenceClassification.
    model = model.merge_and_unload()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(MODEL_DIR, f"bias_lora_{timestamp}")
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    main()
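
# Illustrative inference sketch (not part of the training run): because the
# adapters are merged before saving, the checkpoint reloads as a plain
# RobertaForSequenceClassification. `save_path` below is assumed to point at a
# directory produced by a previous run, e.g. MODEL_DIR/bias_lora_YYYYMMDD_HHMMSS.
#
#   from transformers import RobertaTokenizer, RobertaForSequenceClassification
#   import torch
#
#   tokenizer = RobertaTokenizer.from_pretrained(save_path)
#   model = RobertaForSequenceClassification.from_pretrained(save_path)
#   model.eval()
#   inputs = tokenizer("Example sentence to score.", return_tensors="pt",
#                      truncation=True, max_length=128)
#   with torch.no_grad():
#       predicted_label = model(**inputs).logits.argmax(dim=-1).item()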