# lite_DETECTIVE/cold/classifier.py
# Author: AlbertCAC — commit ce367e1 ("update"), 10.7 kB
# (Provenance header reconstructed from HuggingFace blob-page residue so the
# file is valid Python again.)
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW, lr_scheduler
from .text_cnn import DynamicTextCNN
from tqdm import tqdm
import os
class ToxicTextClassifier(nn.Module):
    """BERT + TextCNN binary classifier for toxic-text detection.

    forward() concatenates the last two BERT hidden layers, feeds them to a
    DynamicTextCNN feature extractor, then a small MLP head produces logits.
    """

    def __init__(self,
                 bert_name='hfl/chinese-roberta-wwm-ext',
                 num_filters=1536,
                 filter_sizes=(1, 2, 3, 4),
                 K=4,
                 fc_dim=128,
                 num_classes=2,
                 dropout=0.1,
                 name='lite'):
        """Build tokenizer, BERT backbone, TextCNN extractor and MLP head.

        Args:
            bert_name: HuggingFace model id of the BERT backbone.
            num_filters: number of CNN filters per kernel size.
            filter_sizes: kernel widths used by the TextCNN.
            K: DynamicTextCNN hyperparameter (passed through unchanged).
            fc_dim: width of the first fully-connected layer of the head.
            num_classes: number of output classes (2 = binary).
            dropout: dropout probability used in the CNN and the head.
            name: run name; creates data/<name>/ and provides defaults for
                the checkpoint/log paths in train_model().
        """
        super().__init__()
        # BUGFIX: `from_tf=True` is a *model* from_pretrained flag; on the
        # tokenizer it is meaningless and was silently ignored at best.
        self.tokenizer = BertTokenizer.from_pretrained(bert_name)
        self.bert = BertModel.from_pretrained(bert_name)
        self.name = name
        self.unfrozen_layers = 0  # how many top BERT layers are trainable
        # *2 because forward() concatenates the last two hidden states.
        hidden_size = self.bert.config.hidden_size * 2
        os.makedirs(f'data/{name}', exist_ok=True)
        self.text_cnn = DynamicTextCNN(hidden_size, num_filters, filter_sizes, K, dropout)
        input_dim = len(filter_sizes) * num_filters
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, fc_dim),
            nn.ReLU(),
            nn.LayerNorm(fc_dim),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, fc_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim // 2, num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()
        self._rebuild_optimizer()
def _rebuild_optimizer(self):
param_groups = [
{'params': self.text_cnn.parameters(), 'lr': 1e-4},
{'params': self.classifier.parameters(), 'lr': 1e-4},
]
if self.unfrozen_layers > 0:
layers = self.bert.encoder.layer[-self.unfrozen_layers:]
bert_params = []
for layer in layers:
for p in layer.parameters():
p.requires_grad = True
bert_params.append(p)
param_groups.append({'params': bert_params, 'lr': 2e-5})
self.optimizer = AdamW(param_groups, weight_decay=0.01)
self.scheduler = lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
patience=2,
)
self.warmup_scheduler = None
def _get_warmup_scheduler(self, warmup_steps=1000):
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
def forward(self, input_ids, attention_mask, token_type_ids=None):
bert_out = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True,
)
hidden = torch.cat(bert_out.hidden_states[-2:], dim=-1)
feat = self.text_cnn(hidden)
return self.classifier(feat)
def validate(self, val_loader, device):
self.eval()
val_loss = 0
correct = 0
total = 0
all_preds = []
all_labels = []
with torch.no_grad():
pbar = tqdm(val_loader, desc='Validating')
for batch in pbar:
ids = batch['input_ids'].to(device)
mask = batch['attention_mask'].to(device)
types = batch['token_type_ids'].to(device)
labels = batch['label'].to(device)
logits = self(ids, mask, types)
loss = self.criterion(logits, labels)
val_loss += loss.item()
preds = torch.argmax(logits, dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
all_preds.extend(preds.cpu().tolist())
all_labels.extend(labels.cpu().tolist())
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
torch.cuda.empty_cache()
def train_model(self, train_loader, val_loader,
num_epochs=3, device='cpu',
save_path=None,
logdir=None,
validate_every=100,
warmup_steps=1000,
early_stop_patience=3):
self.to(device)
for param in self.bert.parameters():
param.requires_grad = False
best_val_loss = float('inf')
global_step = 0
epochs_no_improve = 0
best_model_state = None
if save_path is None:
save_path = f'output/{self.name}.pth'
if logdir is None:
logdir = f'runs/{self.name}'
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
total_loss = 0
correct = 0
total = 0
self.warmup_scheduler = self._get_warmup_scheduler(warmup_steps)
if epoch == 2:
print("Unfreezing 4 layers of BERT")
self.unfrozen_layers = 4
self._rebuild_optimizer()
pbar = tqdm(train_loader, desc='Training')
for batch in pbar:
ids = batch['input_ids'].to(device)
mask = batch['attention_mask'].to(device)
types = batch['token_type_ids'].to(device)
labels = batch['label'].to(device)
logits = self(ids, mask, types)
loss = self.criterion(logits, labels)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
self.optimizer.step()
if global_step < warmup_steps:
self.warmup_scheduler.step()
total_loss += loss.item()
preds = torch.argmax(logits, dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
acc = correct / total
pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.4f}'})
global_step += 1
if global_step % validate_every == 0:
torch.cuda.empty_cache()
self.eval()
with torch.no_grad():
metrics = self.validate(val_loader, device)
val_loss, val_acc = metrics['loss'], metrics['acc']
self.scheduler.step(val_loss)
print(f"\n[Step {global_step}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
print(metrics['report'])
report_text = metrics['report']
conf_mat = metrics['confusion_matrix']
print(report_text)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = self.state_dict()
epochs_no_improve = 0
torch.save(best_model_state, save_path)
print(f"Saved best model (step {global_step}) with loss {best_val_loss:.4f}")
else:
epochs_no_improve += 1
print(f"No improvement for {epochs_no_improve} checks")
if epochs_no_improve >= early_stop_patience:
print(f"Early stopping triggered at step {global_step}!")
self.load_state_dict(best_model_state)
return
self.train()
def predict(self, texts, device='cpu'):
"""Used for inference. Predicts the class of the input text.
Args:
texts (str or list of str): The input text(s) to classify, pass str.
- If a list is passed, the model will classify each text in the list as batch.
- If a single string is passed, the model will classify the text as a single instance.
- If a list of list is passed, the model will treate the first element as detected text and the second element as the context text.
device (str): The device to run the model on ('cpu', 'cuda', or 'mps'). If None, it will use the available device.
max_length (int): The maximum length of the input text.
Returns:
list: A list of dictionaries containing the prediction and probabilities for each input text.
Each dictionary contains:
- 'text': The input text.
- 'prediction': The predicted class (0 or 1).
- 'probabilities': The probabilities for each class.
"""
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
self.eval()
self.to(device)
if isinstance(texts, str):
texts = [texts]
encoded_inputs = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
).to(device)
elif isinstance(texts, list) and all(isinstance(item, list) for item in texts):
encoded_inputs = self.tokenizer(
[item[0] for item in texts],
[item[1] for item in texts],
padding=True,
truncation=True,
return_tensors="pt"
).to(device)
elif isinstance(texts, list) and all(isinstance(item, str) for item in texts):
encoded_inputs = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
).to(device)
else:
raise ValueError("Invalid input type. Expected str or list of str.")
input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']
token_type_ids = encoded_inputs.get('token_type_ids', None)
with torch.no_grad():
logits = self(input_ids, attention_mask, token_type_ids)
probs = torch.softmax(logits, dim=-1)
preds = torch.argmax(probs, dim=-1)
results = []
for i, text in enumerate(texts):
results.append({
'text': text,
'prediction': preds[i].item(),
'probabilities': probs[i].cpu().tolist()
})
return results