import pandas as pd
import torch
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
from torch import nn
# 1. Config
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(BASE_DIR, '../data/reddit_disaster_posts.csv')
MODEL_OUTPUT_DIR = os.path.join(BASE_DIR, 'models/roberta_model')

# --- THE UPGRADE: Multilingual Brain (English + Tagalog) ---
MODEL_NAME = 'xlm-roberta-base'

print(f"--- ALISTO: Training Multilingual Brain ({MODEL_NAME}) ---")
# 2. Load Data
if not os.path.exists(DATA_PATH):
    print("❌ Error: CSV file not found. Run augment_data.py first!")
    exit()

df = pd.read_csv(DATA_PATH)
df = df.dropna(subset=['text', 'label'])
texts = df['text'].tolist()
labels = df['label'].tolist()
print(f"Loaded {len(df)} samples.")
# 3. Split (80% Train, 20% Validation)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)
# 4. Tokenize
print(f"Downloading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(texts):
    return tokenizer(texts, padding=True, truncation=True, max_length=128)

train_encodings = tokenize_function(train_texts)
val_encodings = tokenize_function(val_texts)
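# Note: calling the tokenizer on a list of strings returns a BatchEncoding that
# acts like a dict of per-sample lists (illustrative shape, not real token ids):
#   {'input_ids': [[0, 4, ...], ...], 'attention_mask': [[1, 1, ...], ...]}
# DisasterDataset below indexes into these lists one sample at a time.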
# 5. Dataset Class
class DisasterDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = DisasterDataset(train_encodings, train_labels)
val_dataset = DisasterDataset(val_encodings, val_labels)
# --- CUSTOM TRAINER WITH WEIGHTED LOSS ---
# Punishes the model 3x more if it misses a Rescue Request (False Negative)
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # [1.0, 3.0] -> Label 1 is 3x more important than Label 0
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 3.0]).to(model.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
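# A hedged, self-contained illustration of the weighting (toy values, not part
# of the pipeline): with logits that confidently predict label 0, the loss on
# a missed label-1 sample is exactly 3x its unweighted value.
_demo_logits = torch.tensor([[2.0, 0.0], [2.0, 0.0]])
_demo_labels = torch.tensor([0, 1])  # second sample is a missed rescue request
_plain = nn.CrossEntropyLoss(reduction='none')(_demo_logits, _demo_labels)
_weighted = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 3.0]), reduction='none')(_demo_logits, _demo_labels)
print(f"Demo per-sample loss, unweighted: {_plain.tolist()} | weighted: {_weighted.tolist()}")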
# Metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
# 6. Model Initialization
print(f"Downloading base model {MODEL_NAME}...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
# 7. Training Args
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=15,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    learning_rate=2e-5,
    logging_dir='./logs',
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    seed=42
)
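# With load_best_model_at_end=True and metric_for_best_model="f1", the
# EarlyStoppingCallback below tracks eval F1 each epoch, stops training after
# 2 epochs without improvement, and the best checkpoint is restored at the end.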
# 8. Train
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
| print("Starting training (XLM-R + Weighted Loss)...") | |
| trainer.train() | |
| # 9. Save | |
| print(f"Saving upgraded model to {MODEL_OUTPUT_DIR}...") | |
| model.save_pretrained(MODEL_OUTPUT_DIR) | |
| tokenizer.save_pretrained(MODEL_OUTPUT_DIR) | |
| print("β Multilingual Brain Training Complete.") |