# Quivara's picture
# Fresh upload with LFS
# bdb271a
import pandas as pd
import torch
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
Trainer,
TrainingArguments,
EarlyStoppingCallback
)
from torch import nn
# 1. Config
# Paths are anchored to this script's own directory rather than the current
# working directory, so the script can be launched from anywhere.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH, MODEL_OUTPUT_DIR = (
    os.path.join(BASE_DIR, '../data/reddit_disaster_posts.csv'),
    os.path.join(BASE_DIR, 'models/roberta_model'),
)
# Multilingual encoder, so a single model handles both English and Tagalog posts.
MODEL_NAME = 'xlm-roberta-base'
print(f"--- ALISTO: Training Multilingual Brain ({MODEL_NAME}) ---")
# 2. Load Data
if not os.path.exists(DATA_PATH):
    # Raising SystemExit with a message prints it to stderr and exits with
    # status 1, so shells and CI pipelines can detect the failure.
    raise SystemExit("❌ Error: CSV file not found. Run augment_data.py first!")
# Drop rows missing either field, then force integer labels: a label column
# that contained NaNs is read as float (0.0 / 1.0), and float labels are not
# valid class indices for the cross-entropy loss used in training.
df = pd.read_csv(DATA_PATH).dropna(subset=['text', 'label']).astype({'label': int})
texts = df['text'].tolist()
labels = df['label'].tolist()
print(f"Loaded {len(df)} samples.")
# 3. Split (80% Train, 20% Validation)
# Stratifying on the labels keeps the class ratio the same in both splits;
# the fixed random_state makes the split reproducible across runs.
(train_texts, val_texts,
 train_labels, val_labels) = train_test_split(
    texts,
    labels,
    stratify=labels,
    test_size=0.2,
    random_state=42,
)
# 4. Tokenize
print(f"Downloading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(texts):
    """Encode a list of strings with the module-level ``tokenizer``.

    Each text is truncated to at most 128 tokens, and the whole list is
    padded together (``padding=True``) so all encodings share one length.
    """
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding=True,
    )
    return encoded


# Encode the training split first, then the validation split.
train_encodings, val_encodings = map(tokenize_function, (train_texts, val_texts))
# 5. Dataset Class
class DisasterDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-computed tokenizer encodings with labels.

    Args:
        encodings: mapping from input name (the tokenizer's output keys) to a
            per-example sequence; element ``idx`` of every sequence belongs to
            example ``idx``.
        labels: one class id per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including ``labels``."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Force an integer (long) label tensor: a float label (e.g. 1.0 from a
        # float CSV column) is not a valid class-index target.
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        """Number of examples, taken from the length of ``labels``."""
        return len(self.labels)
# Wrap each split's encodings and labels so the Trainer can index them.
train_dataset = DisasterDataset(encodings=train_encodings, labels=train_labels)
val_dataset = DisasterDataset(encodings=val_encodings, labels=val_labels)
# --- CUSTOM TRAINER WITH WEIGHTED LOSS ---
# Punishes the model 3x more if it misses a Rescue Request (False Negative)
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    ``class_weights[i]`` scales the loss of examples whose true label is
    ``i``. The default ``(1.0, 3.0)`` makes mistakes on label-1 examples three
    times as costly as mistakes on label-0 examples. Override the attribute on
    a subclass to change the weighting; its length must equal
    ``model.config.num_labels``.
    """

    class_weights = (1.0, 3.0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the class-weighted cross-entropy loss for one batch.

        Args:
            model: the model being trained, possibly wrapped by the Trainer
                (e.g. ``nn.DataParallel`` on multi-GPU).
            inputs: batch dict; must contain ``labels``, which is removed
                before the forward pass.
            return_outputs: when True, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch, **kwargs: accepted for compatibility with the
                ``Trainer.compute_loss`` signature; unused here.
        """
        # Pop (not get) the labels so the model does not also compute its own
        # unweighted loss in the forward pass, which would be discarded.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Build the weights on the logits' device: a wrapped model such as
        # nn.DataParallel has no ``.device`` attribute.
        weight = torch.tensor(self.class_weights, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
# Metrics
def compute_metrics(pred):
    """Compute validation metrics for the Trainer.

    Args:
        pred: the Trainer's evaluation output. ``pred.predictions`` holds the
            logits (or a tuple whose first element is the logits) and
            ``pred.label_ids`` holds the gold labels.

    Returns:
        dict with ``accuracy`` plus binary ``f1``, ``precision`` and
        ``recall`` (sklearn ``average='binary'``, positive label 1).
    """
    labels = pred.label_ids
    logits = pred.predictions
    # If the model returns extra outputs, predictions may arrive as a tuple
    # with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # zero_division=0: if the model predicts no positives, precision is
    # undefined; report 0.0 explicitly instead of sklearn's default of 0.0
    # plus an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
# 6. Model Initialization
print(f"Downloading base model {MODEL_NAME}...")
# Two output classes, matching the binary (label 0 / label 1) setup used by
# the weighted loss and the binary metrics.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)
# 7. Training Args
# NOTE(review): output_dir and logging_dir are relative to the current working
# directory, unlike DATA_PATH / MODEL_OUTPUT_DIR, which are anchored to
# BASE_DIR. Confirm this is intended.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=15,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # 500 steps at 8 examples per device covers 4,000 training examples. If
    # the training split is smaller, warmup spans several epochs.
    # NOTE(review): consider warmup_ratio if the dataset is small.
    warmup_steps=500,
    weight_decay=0.01,
    learning_rate=2e-5,
    logging_dir='./logs',
    logging_steps=50,
    # Evaluate and checkpoint once per epoch; load_best_model_at_end requires
    # the two strategies to match.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # Keep at most 2 checkpoints on disk instead of one per epoch (up to 15).
    # The best checkpoint by metric_for_best_model is always retained.
    save_total_limit=2,
    seed=42
)
# 8. Train
# Stop training once the metric named by training_args.metric_for_best_model
# has not improved for 2 consecutive evaluations.
stopper = EarlyStoppingCallback(early_stopping_patience=2)
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[stopper],
)
print("Starting training (XLM-R + Weighted Loss)...")
trainer.train()
# 9. Save
print(f"Saving upgraded model to {MODEL_OUTPUT_DIR}...")
# trainer.save_model writes the trainer's model, which holds the best
# checkpoint because load_best_model_at_end=True. It writes only from the
# main process, so a multi-process run does not have every rank writing the
# same files. It must still be called on every rank.
trainer.save_model(MODEL_OUTPUT_DIR)
# The tokenizer is not passed to the Trainer, so save it separately, again
# from the main process only.
if trainer.is_world_process_zero():
    tokenizer.save_pretrained(MODEL_OUTPUT_DIR)
print("✅ Multilingual Brain Training Complete.")